diff --git a/dflash/CMakeLists.txt b/dflash/CMakeLists.txt index 2bdbb219..ad6d4616 100644 --- a/dflash/CMakeLists.txt +++ b/dflash/CMakeLists.txt @@ -248,6 +248,11 @@ add_library(dflash_common STATIC src/qwen35/qwen35_layer_split_dflash_target.cpp src/qwen35/layer_split_daemon_loop.cpp src/qwen35/qwen35_daemon.cpp + src/qwen35/qwen35_mtp.cpp + src/qwen35/qwen35_mtp_graph.cpp + src/qwen35/qwen35_mtp_loader.cpp + src/common/mtp_chain_runner.cpp + src/common/mtp_orchestrator.cpp src/common/sampler.cpp src/common/daemon_loop.cpp src/common/gguf_inspect.cpp @@ -521,6 +526,13 @@ if(DFLASH27B_TESTS) target_include_directories(test_kv_quant PRIVATE ${DFLASH27B_SRC_INCLUDE_DIRS}) target_link_libraries(test_kv_quant PRIVATE dflash_common) endif() + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_common_mtp_orchestrator.cpp") + add_executable(test_common_mtp_orchestrator test/test_common_mtp_orchestrator.cpp) + target_include_directories(test_common_mtp_orchestrator PRIVATE + ${DFLASH27B_SRC_INCLUDE_DIRS} + ${CMAKE_CURRENT_SOURCE_DIR}/deps/llama.cpp/ggml/include) + target_link_libraries(test_common_mtp_orchestrator PRIVATE dflash_common ggml-base) + endif() if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_draft_vs_reference.cpp") add_executable(test_draft_vs_reference test/test_draft_vs_reference.cpp) target_link_libraries(test_draft_vs_reference PRIVATE dflash_common) diff --git a/dflash/scripts/server.py b/dflash/scripts/server.py index 01bd8196..5146c509 100644 --- a/dflash/scripts/server.py +++ b/dflash/scripts/server.py @@ -793,7 +793,11 @@ def build_app(target: Path, draft: Path | None, bin_path: Path, budget: int, max extra_daemon_args: list[str] | None = None, lazy_draft: bool = False, verbose_daemon: bool = False, - force_no_thinking: bool = False) -> FastAPI: + force_no_thinking: bool = False, + mtp_gguf: Path | None = None, + mtp_gamma: int = 3, + mtp_draft_source: str = "chain", + mtp_draft_topk: int = 1) -> FastAPI: import asyncio if _extra_daemon_has_target_sharding(extra_daemon_args): if prefix_cache_slots > 0 or prefill_cache_slots > 0: @@ -850,6 +854,19 @@ async def _openai_compat_error_handler(_request: Request, exc: OpenAICompatError cmd = [bin_abs, str(target), "--daemon", f"--max-ctx={max_ctx}", f"--stream-fd={stream_fd_val}"] + elif mtp_gguf is not None: + # MTP mode: no --draft (MTP head lives inside target or mtp_gguf), + # no DFlash flags. Daemon dispatches to MTP code path via --mtp-gguf. + cmd = [bin_abs, str(target), "--daemon", + f"--max-ctx={max_ctx}", + f"--stream-fd={stream_fd_val}", + f"--mtp-gguf={mtp_gguf}", + f"--gamma={mtp_gamma}", + "--draft-source", mtp_draft_source] + if mtp_draft_source == "mtp_topk": + cmd.append(f"--draft-topk={mtp_draft_topk}") + if extra_daemon_args: + cmd.extend(extra_daemon_args) else: if draft is None: raise SystemExit("qwen35 arch requires --draft ") @@ -2858,6 +2875,20 @@ def main(): help="Server-level guard: prevent any request from enabling thinking mode " "via chat_template_kwargs. Useful on hardware (e.g. gfx1151/Strix Halo) " "where thinking chains consume n_gen budget without benefit.") + # ── MTP (Multi-Token Prediction) speculator ────────────────────────────── + # When --mtp-gguf is set, the daemon runs MTP-head speculation instead of + # DFlash+DDTree. --draft is ignored (the MTP head is in the same GGUF as + # target, or a separate fused GGUF). Prefix-cache slots are auto-disabled + # in MTP mode because RESTORE does not snapshot MTP head KV yet. + ap.add_argument("--mtp-gguf", type=Path, default=None, + help="Path to MTP-fused GGUF. When set, daemon runs MTP " + "speculation; --draft and DFlash flags are ignored.") + ap.add_argument("--mtp-gamma", type=int, default=3, + help="MTP chain depth (default 3; recommended D=3 per matrix bench)") + ap.add_argument("--mtp-draft-source", choices=["chain", "mtp_topk"], default="chain", + help="MTP draft generation strategy (default chain)") + ap.add_argument("--mtp-draft-topk", type=int, default=1, + help="Top-K for mtp_topk draft source (default 1, ignored for chain)") add_cli_flags(ap) args = ap.parse_args() prefill_cfg = config_from_args(args) @@ -2906,6 +2937,17 @@ def main(): # through the laguna daemon now, so --prefill-compression and # --prefix-cache-slots behave the same as on the qwen35 path. draft = None + elif args.mtp_gguf is not None: + # MTP mode: --draft is ignored; MTP head lives in the target (or in --mtp-gguf + # if separate). Force prefix/prefill cache off — RESTORE doesn't snapshot + # MTP head KV yet (planned for a follow-up PR). + if not args.mtp_gguf.is_file(): + raise SystemExit(f"--mtp-gguf not found at {args.mtp_gguf}") + draft = None + if args.prefix_cache_slots > 0 or args.prefill_cache_slots > 0: + print(" [cfg] MTP mode: disabling prefix/prefill cache (MTP head KV snapshot not implemented)") + args.prefix_cache_slots = 0 + args.prefill_cache_slots = 0 else: draft = resolve_draft(args.draft) if args.draft.is_dir() else args.draft if not draft.is_file(): @@ -2938,7 +2980,11 @@ def main(): extra_daemon_args=placement.daemon_args or None, lazy_draft=args.lazy_draft, verbose_daemon=args.verbose_daemon, - force_no_thinking=args.no_thinking) + force_no_thinking=args.no_thinking, + mtp_gguf=args.mtp_gguf, + mtp_gamma=args.mtp_gamma, + mtp_draft_source=args.mtp_draft_source, + mtp_draft_topk=args.mtp_draft_topk) import uvicorn logging.basicConfig( diff --git a/dflash/src/common/attn_masks.h b/dflash/src/common/attn_masks.h index ec25ca56..9b11bc8b 100644 --- a/dflash/src/common/attn_masks.h +++ b/dflash/src/common/attn_masks.h @@ -51,13 +51,24 @@ inline void build_causal_mask(std::vector & out, // Build an ancestor-only attention mask for DDTree tree-structured verify. // Each query position i can attend to its ancestors in the tree (including // itself) plus all past KV positions. +// +// kv_pad_override: when nonzero, pin the kv (column) stride to this value +// instead of the helper's natural `align_up(past_length + N - win_start, +// kq_stride_pad)`. Needed when the consumer tensor was allocated with a +// fixed kv extent (e.g. build_target_step_tree sizes sg.attn_mask at +// align_up(cache.max_ctx + N, kq_stride_pad)) and the helper-computed +// stride would not match the tensor's actual row pitch. Default 0 keeps +// existing behavior. inline void build_tree_mask(const DDTree & tree, int past_length, std::vector & out_mask, int kq_stride_pad, - int win_start = 0) { + int win_start = 0, + int kv_pad_override = 0) { const int N = 1 + tree.n_nodes; const int win_len = past_length + N - win_start; - const int kv_pad = align_up(win_len, kq_stride_pad); + const int kv_pad = kv_pad_override > 0 + ? align_up(kv_pad_override, kq_stride_pad) + : align_up(win_len, kq_stride_pad); const int q_pad = align_up(N, KQ_MASK_PAD); out_mask.assign((size_t)kv_pad * q_pad, F16_NEG_INF); for (int q = 0; q < N; q++) { diff --git a/dflash/src/common/backend_factory.cpp b/dflash/src/common/backend_factory.cpp index 01804569..fd7ad5b3 100644 --- a/dflash/src/common/backend_factory.cpp +++ b/dflash/src/common/backend_factory.cpp @@ -8,6 +8,9 @@ #include "qwen3_backend.h" #include "gemma4_backend.h" +#include "gguf.h" + +#include #include namespace dflash::common { @@ -17,6 +20,26 @@ std::string detect_arch(const char * model_path) { return info.arch; } +bool gguf_contains_mtp_tensors(const std::string & path) { + gguf_init_params gp{}; + gp.no_alloc = true; + gp.ctx = nullptr; + gguf_context * gguf = gguf_init_from_file(path.c_str(), gp); + if (!gguf) return false; + + // MTP-capable GGUF files carry `qwen35.nextn_predict_layers` > 0. + // This is the canonical indicator used by qwen35_mtp_loader.cpp. + bool found = false; + int64_t kid = gguf_find_key(gguf, "qwen35.nextn_predict_layers"); + if (kid >= 0) { + uint32_t n = gguf_get_val_u32(gguf, kid); + found = (n > 0); + } + + gguf_free(gguf); + return found; +} + std::unique_ptr create_backend(const BackendArgs & args) { if (!args.model_path) { std::fprintf(stderr, "[backend_factory] model_path is null\n"); @@ -32,6 +55,22 @@ std::unique_ptr create_backend(const BackendArgs & args) { std::fprintf(stderr, "[backend_factory] detected arch=%s\n", arch.c_str()); + // Unset must have been resolved to None by arg parsing before reaching here. + assert(args.mtp_source != MtpSource::Unset && + "MtpSource::Unset must be resolved by arg parsing before reaching the backend factory"); + + // Resolve MtpSource::Auto before constructing the backend. + MtpSource resolved_source = args.mtp_source; + if (resolved_source == MtpSource::Auto) { + if (gguf_contains_mtp_tensors(args.model_path)) { + std::fprintf(stderr, "[backend_factory] mtp=auto: nextn_predict_layers found -> Native\n"); + resolved_source = MtpSource::Native; + } else { + std::fprintf(stderr, "[backend_factory] mtp=auto: no nextn_predict_layers -> None\n"); + resolved_source = MtpSource::None; + } + } + if (arch == "qwen35") { Qwen35Config cfg; cfg.target_path = args.model_path; @@ -50,6 +89,27 @@ std::unique_ptr create_backend(const BackendArgs & args) { cfg.ddtree_temp = args.ddtree_temp; cfg.ddtree_chain_seed = args.ddtree_chain_seed; cfg.use_feature_mirror = args.use_feature_mirror; + cfg.mtp_gamma = args.mtp_gamma; + cfg.mtp_use_topk = args.mtp_use_topk; + cfg.mtp_draft_topk = args.mtp_draft_topk; + + // Map resolved MtpSource to the paths Qwen35Backend expects. + // Qwen35Backend uses cfg_.mtp_gguf_path != nullptr as the MTP-active sentinel. + switch (resolved_source) { + case MtpSource::Native: + // MTP tensors live inside the target GGUF itself. + cfg.mtp_gguf_path = args.model_path; + break; + case MtpSource::ExternalDrafter: + cfg.mtp_gguf_path = args.mtp_gguf_path; + break; + case MtpSource::None: + case MtpSource::Auto: // fully resolved above; arm is unreachable. + case MtpSource::Unset: // guarded by assert above; arm is unreachable. + default: + cfg.mtp_gguf_path = nullptr; + break; + } auto backend = std::make_unique(cfg); if (!backend->init()) { diff --git a/dflash/src/common/backend_factory.h b/dflash/src/common/backend_factory.h index 5a419550..4cde80c2 100644 --- a/dflash/src/common/backend_factory.h +++ b/dflash/src/common/backend_factory.h @@ -18,6 +18,16 @@ namespace dflash::common { +// ─── MTP source selection ──────────────────────────────────────────────── +// Replaces the old free-form mtp_draft_source string (@howard0su #237, line 59). +enum class MtpSource { + Unset, // internal sentinel: --mtp-source not provided (never escapes arg parsing) + None, // no MTP speculator + Native, // MTP heads co-located in the target GGUF (e.g. unsloth single-file) + ExternalDrafter, // separate MTP-head GGUF supplied via mtp_gguf_path + Auto, // probe target GGUF for nextn_predict_layers; Native if found, else None +}; + // ─── Backend creation arguments ───────────────────────────────────────── // A superset of all per-arch config fields. The factory reads only those // relevant to the detected arch; unused fields are silently ignored. @@ -51,15 +61,39 @@ struct BackendArgs { float ddtree_temp = 1.0f; bool ddtree_chain_seed = true; bool use_feature_mirror = false; + + // MTP (Multi-Token Prediction) speculator — mutually exclusive with --draft. + // mtp_source drives which loading path is taken: + // Unset → internal default; --mtp-source not provided; resolved to None after + // legacy-flag inference (never reaches the backend factory as Unset). + // None → MTP disabled; mtp_gguf_path ignored. + // Native → MTP heads embedded in model_path GGUF (single-file, e.g. unsloth). + // mtp_gguf_path is left nullptr; the factory sets it to model_path. + // ExternalDrafter→ Separate MTP-head GGUF at mtp_gguf_path (required). + // Auto → factory calls gguf_contains_mtp_tensors(model_path): if true, + // resolves to Native; otherwise resolves to None. + MtpSource mtp_source = MtpSource::Unset; + const char * mtp_gguf_path = nullptr; // required only for ExternalDrafter + int mtp_gamma = 0; // 0 = MTP loaded but not active; >0 = chain depth + bool mtp_use_topk = false; // false = chain (default), true = mtp_topk strategy + int mtp_draft_topk = 1; }; // ─── Factory function ─────────────────────────────────────────────────── // Inspects model_path GGUF metadata, constructs the correct backend, and // calls init(). Returns nullptr on failure (diagnostic printed to stderr). +// When args.mtp_source == Auto, resolves to Native or None before +// constructing; the resolved value is not written back into args. std::unique_ptr create_backend(const BackendArgs & args); // Returns the detected architecture string without creating a backend. // Useful for early dispatch (e.g. printing which backend will be used). std::string detect_arch(const char * model_path); +// Returns true if the GGUF at `path` contains MTP-head tensors. +// Heuristic: presence of `qwen35.nextn_predict_layers` metadata key with +// a value > 0. Pure metadata scan — no tensor allocation, no GPU touch. +// Used by create_backend() when mtp_source == Auto. +bool gguf_contains_mtp_tensors(const std::string & path); + } // namespace dflash::common diff --git a/dflash/src/common/dflash_target.h b/dflash/src/common/dflash_target.h index 56fd4bec..a2e724dc 100644 --- a/dflash/src/common/dflash_target.h +++ b/dflash/src/common/dflash_target.h @@ -14,11 +14,36 @@ #include #include +struct ggml_backend; +typedef struct ggml_backend * ggml_backend_t; +struct ggml_tensor; + namespace dflash::common { +struct DDTree; // forward — see common/ddtree.h + struct DFlashTarget { virtual ~DFlashTarget() = default; + // Return the ggml backend used by this target's graph compute. Default + // returns nullptr; callers (e.g. Qwen3.6 MTP) that want to build CUDA + // cgraphs against the same backend should check this and fall back if + // it's null. + virtual ggml_backend_t backend() const { return nullptr; } + + // Optional: return the LM-head weight tensor on the target's backend + // (shape [n_embd, n_vocab], used by ggml_mul_mat). When non-null, the + // Qwen3.6 MTP step graph fuses `mul_mat(W, x_normed) -> argmax` into + // its own cgraph, skipping a hidden -> host -> separate-cgraph round + // trip per step. Default returns nullptr so existing targets (CPU + // stubs) keep the project_hidden_to_* fallback path. + virtual ggml_tensor * lm_head_weight() const { return nullptr; } + + // Optional: causal attention window the target's full-attn blocks use + // (kv_len - fa_window). The MTP head uses the same window so it sees + // the same active context. 0 means full causal context. + virtual int fa_window() const { return 0; } + // ── Target forward ────────────────────────────────────────────── // Run a batch of tokens through the target model. Returns the argmax @@ -33,6 +58,26 @@ struct DFlashTarget { int & last_tok, std::vector * all_argmax = nullptr) = 0; + // Tree-structured verify: run a flat DFS-ordered DDTree through the + // target with an ancestor-only attention mask. `flat_tokens[0]` is the + // tree root (= last accepted token), `flat_tokens[1..N-1]` are the + // DFS-ordered tree nodes (mirroring DDTree::token_ids). `tree.n_nodes` + // must equal flat_tokens.size() - 1. On success, `out_argmax` is the + // target's argmax at each of the N tree positions (size == N) and + // (if non-null) `out_logits` is the raw logits laid out as + // [N × vocab] floats. Returns false by default so existing targets + // that haven't wired tree-verify can be detected by callers; concrete + // targets override to plug in build_target_step_tree + the tree mask. + virtual bool verify_tree(const std::vector & flat_tokens, + const DDTree & tree, + int base_pos, + std::vector & out_argmax, + std::vector * out_logits = nullptr) { + (void)flat_tokens; (void)tree; (void)base_pos; + (void)out_argmax; (void)out_logits; + return false; + } + // ── KV state management ───────────────────────────────────────── // Snapshot KV cache state before speculative verify, so it can be @@ -42,6 +87,32 @@ struct DFlashTarget { // Restore KV cache to the last snapshot (undo speculative forward). virtual bool restore_kv() = 0; + // Rollback DeltaNet SSM/conv + full-attn KV to the accepted-path tail of + // the most recent verify_tree() call. accepted_dfs[0] must be 0 (root). + // Returns false if unsupported; callers must treat false as fatal in + // multi-iteration tree-spec loops (poisoned KV/SSM otherwise). + virtual bool restore_kv_at_dfs(const std::vector & accepted_dfs) { + (void)accepted_dfs; + return false; + } + + // Roll back DeltaNet SSM/conv + full-attn KV to slot `accept_n` of the + // most recent verify_batch chain. Requires chain capture enabled. + // Postcondition: cache cur_pos = base_pos + accept_n + 1. + // Returns false if unsupported; chain runner falls back to snapshot+recommit. + virtual bool restore_kv_at_chain(int accept_n) { + (void)accept_n; + return false; + } + + // Enable per-position DeltaNet intermediate capture in verify_batch. + // Off by default; unsafe when n_tokens > max_verify_tokens. + virtual void enable_chain_capture(bool /*on*/) {} + + // Record linear-chain topology before verify_batch so restore_kv_at_chain() + // can locate the rollback slot. Must be called before each capturable iter. + virtual void capture_topology_for_chain(int /*n_tokens*/, int /*base_pos*/) {} + // ── Token utilities ───────────────────────────────────────────── // Check if a token is end-of-sequence for this model. @@ -61,6 +132,21 @@ struct DFlashTarget { int n_tokens, std::vector & tokens_out) = 0; + // Optional: project draft hidden states through the target's lm_head + // and return the full raw logits (n_tokens * vocab floats) on host. + // Used by MTP drafters that need a top-K surface for DDTree (the + // argmax path above hides the distribution). Default returns false so + // existing targets compile unchanged; concrete targets that wire it + // up resize `logits_out` to n_tokens * vocab and return true. The + // `out_vocab` param reports the vocab dim back to the caller. + virtual bool project_hidden_to_logits(const float * /*hidden*/, + int /*n_tokens*/, + std::vector & /*logits_out*/, + int & out_vocab) { + out_vocab = 0; + return false; + } + // ── Configuration for draft model ─────────────────────────────── // Target's hidden dimension (draft model must match). @@ -72,6 +158,56 @@ struct DFlashTarget { // Which target layers to capture intermediate activations from. // The draft model's fc layer expects exactly this many feature slices. virtual const std::vector & capture_layer_ids() const = 0; + + // Return the backbone's final post-norm hidden state for the last committed + // token (hidden_size() floats, F32). Populated by verify_batch. + // Returns nullptr if not yet available (e.g. before first verify_batch). + // Default implementation returns nullptr; Qwen35DFlashTarget overrides it. + virtual const float * last_hidden() const { return nullptr; } + + // Return the full post-norm hidden sequence from the MOST RECENT + // verify_batch call: n_tokens * hidden_size() floats, F32, laid out as + // [token_0_hidden, token_1_hidden, ..., token_{n_tokens-1}_hidden]. + // *out_n_tokens is set to the number of tokens captured (matches the + // n_tokens passed to verify_batch). Default returns nullptr. + virtual const float * last_hidden_seq(int * out_n_tokens) const { + if (out_n_tokens) *out_n_tokens = 0; + return nullptr; + } + + // Return the post-norm hidden at an ABSOLUTE sequence position, if that + // position is covered by the most recent verify_batch's hidden capture. + // The Qwen3.6 MTP head needs h_{base_pos-1} for its input pair at each + // chain step, which equals last_hidden() only on the first chain step + // (right after prefill); subsequent steps need a hidden from earlier in + // the most recent verify_batch chunk. Returns nullptr if out of range. + virtual const float * hidden_at_pos(int abs_pos) const { + (void)abs_pos; + return nullptr; + } + + // Pre-final-output-norm variant of hidden_at_pos. Mirrors llama.cpp + // PR #22673's `t_h_pre_norm`. The Qwen3.6 MTP head's hnorm normalises + // h_prev internally; feeding it the post-output-norm tensor double- + // normalises and compounds per-depth rejection on D>=2 chains. Spec- + // chain callers must prefer this accessor for the outer h_prev_0 seed + // and fall back to hidden_at_pos() only if it returns nullptr (e.g. + // adapters that do not yet capture the pre-norm sequence). Default + // returns nullptr; Qwen35DFlashTarget overrides it when hidden-seq + // capture is enabled. + virtual const float * hidden_at_pos_pre_norm(int abs_pos) const { + (void)abs_pos; + return nullptr; + } + + // Enable per-position post-norm + pre-norm hidden capture during the + // next verify_batch calls. Default no-op; Qwen35DFlashTarget overrides. + virtual void enable_hidden_seq_capture(bool /*on*/) {} + + // FULL_SEQ during prefill (warm_head_kv reads per-position); LAST_ROW_ONLY + // during decode-side chain verifies. Default no-op. + enum class VerifyCaptureScope { FULL_SEQ, LAST_ROW_ONLY }; + virtual void set_hidden_capture_scope(VerifyCaptureScope /*scope*/) {} }; } // namespace dflash::common diff --git a/dflash/src/common/gguf_metadata.h b/dflash/src/common/gguf_metadata.h new file mode 100644 index 00000000..2f2b321a --- /dev/null +++ b/dflash/src/common/gguf_metadata.h @@ -0,0 +1,88 @@ +// common/gguf_metadata.h — Shared helpers for reading GGUF metadata. +// +// Provides typed "get or default" accessors and "require" accessors for +// gguf_context key-value pairs, plus an architecture validation helper. +// Use these in every loader; do not inline equivalent helpers per-arch. +// +// Include convention: #include "common/gguf_metadata.h" +// Never: ../common/gguf_metadata.h or absolute paths. + +#pragma once + +#include "gguf.h" + +#include +#include + +namespace dflash::common { + +// ── Read-with-default ───────────────────────────────────────────────────── +// Return the stored value when the key is present, default_val otherwise. + +inline uint32_t gguf_get_u32_or(struct gguf_context * gguf, const char * key, uint32_t default_val) { + int64_t id = gguf_find_key(gguf, key); + return (id >= 0) ? gguf_get_val_u32(gguf, id) : default_val; +} + +inline int32_t gguf_get_i32_or(struct gguf_context * gguf, const char * key, int32_t default_val) { + int64_t id = gguf_find_key(gguf, key); + return (id >= 0) ? gguf_get_val_i32(gguf, id) : default_val; +} + +inline float gguf_get_f32_or(struct gguf_context * gguf, const char * key, float default_val) { + int64_t id = gguf_find_key(gguf, key); + return (id >= 0) ? gguf_get_val_f32(gguf, id) : default_val; +} + +inline std::string gguf_get_str_or(struct gguf_context * gguf, const char * key, const std::string & default_val) { + int64_t id = gguf_find_key(gguf, key); + return (id >= 0) ? std::string(gguf_get_val_str(gguf, id)) : default_val; +} + +// ── Required reads ──────────────────────────────────────────────────────── +// Return false and write a descriptive error when the key is absent. + +inline bool gguf_require_u32(struct gguf_context * gguf, const char * key, + uint32_t & out, std::string & out_error) { + int64_t id = gguf_find_key(gguf, key); + if (id < 0) { + out_error = std::string("missing required GGUF key: ") + key; + return false; + } + out = gguf_get_val_u32(gguf, id); + return true; +} + +inline bool gguf_require_str(struct gguf_context * gguf, const char * key, + std::string & out, std::string & out_error) { + int64_t id = gguf_find_key(gguf, key); + if (id < 0) { + out_error = std::string("missing required GGUF key: ") + key; + return false; + } + out = gguf_get_val_str(gguf, id); + return true; +} + +// ── Architecture validation ─────────────────────────────────────────────── +// Return true when "general.architecture" equals expected_arch. +// On mismatch or absence, writes a descriptive error and returns false. + +inline bool gguf_check_architecture(struct gguf_context * gguf, + const char * expected_arch, + std::string & out_error) { + int64_t id = gguf_find_key(gguf, "general.architecture"); + if (id < 0) { + out_error = "missing required GGUF key: general.architecture"; + return false; + } + const char * arch = gguf_get_val_str(gguf, id); + if (std::string(arch) != expected_arch) { + out_error = std::string("unexpected architecture: got '") + arch + + "', expected '" + expected_arch + "'"; + return false; + } + return true; +} + +} // namespace dflash::common diff --git a/dflash/src/common/gguf_mmap.h b/dflash/src/common/gguf_mmap.h new file mode 100644 index 00000000..37416f31 --- /dev/null +++ b/dflash/src/common/gguf_mmap.h @@ -0,0 +1,219 @@ +// common/gguf_mmap.h — RAII wrapper for platform-conditional mmap of GGUF files. +// +// Encapsulates POSIX mmap / Windows MapViewOfFile behind a single interface. +// Loaders that materialize tensor data from disk must use this class instead +// of inlining equivalent platform-conditional code. +// +// Include convention: #include "common/gguf_mmap.h" +// Never: ../common/gguf_mmap.h or absolute paths. + +#pragma once + +#include +#include + +namespace dflash::common { + +class GgufMmap { +public: + GgufMmap() = default; + ~GgufMmap(); + + // Non-copyable. + GgufMmap(const GgufMmap &) = delete; + GgufMmap & operator=(const GgufMmap &) = delete; + + // Movable — transfers ownership, leaves source empty. + GgufMmap(GgufMmap &&) noexcept; + GgufMmap & operator=(GgufMmap &&) noexcept; + + // Open the file at path and mmap it read-only. + // Returns true on success. On failure, writes a human-readable description + // to out_error and leaves this object in the default (empty) state. + bool open(const std::string & path, std::string & out_error); + + const void * data() const; // nullptr when not open + size_t size() const; // 0 when not open + bool is_open() const; + + // Transfer ownership of the mmap'd region to the caller. + // After release() this object is empty (is_open() == false). + // The caller is responsible for unmapping on POSIX or UnmapViewOfFile on + // Windows, and for closing the fd on POSIX. + struct OwnedRegion { + const void * data; + size_t size; + int fd; // POSIX fd; -1 on Windows (handle already closed) + }; + OwnedRegion release(); + +private: + const void * data_ = nullptr; + size_t size_ = 0; +#if defined(_WIN32) + void * handle_ = nullptr; // HANDLE (Windows mapping object, reinterpret_cast'd) +#else + int fd_ = -1; +#endif +}; + +} // namespace dflash::common + +// ── Implementation ──────────────────────────────────────────────────────── +// Header-only: the platform-conditional code lives here rather than in a .cpp +// so that loaders can include a single file without adding a new translation unit. + +#include + +#if defined(_WIN32) +#if !defined(NOMINMAX) +#define NOMINMAX +#endif +#if !defined(WIN32_LEAN_AND_MEAN) +#define WIN32_LEAN_AND_MEAN +#endif +#include +#else +#include +#include +#include +#include +#include +#endif + +namespace dflash::common { + +inline GgufMmap::~GgufMmap() { + if (!data_) return; +#if defined(_WIN32) + UnmapViewOfFile(const_cast(data_)); + if (handle_) CloseHandle(reinterpret_cast(handle_)); +#else + ::munmap(const_cast(data_), size_); + if (fd_ >= 0) ::close(fd_); +#endif + data_ = nullptr; + size_ = 0; +} + +inline GgufMmap::GgufMmap(GgufMmap && o) noexcept + : data_(o.data_), size_(o.size_) +#if defined(_WIN32) + , handle_(o.handle_) +#else + , fd_(o.fd_) +#endif +{ + o.data_ = nullptr; + o.size_ = 0; +#if defined(_WIN32) + o.handle_ = nullptr; +#else + o.fd_ = -1; +#endif +} + +inline GgufMmap & GgufMmap::operator=(GgufMmap && o) noexcept { + if (this != &o) { + this->~GgufMmap(); + new (this) GgufMmap(std::move(o)); + } + return *this; +} + +inline bool GgufMmap::open(const std::string & path, std::string & out_error) { + // Idempotency: if already open, release prior mapping before re-opening. + // This prevents leaking the prior fd/mapping and ensures that, on failure, + // the object is left in the default empty state (not half-overwritten). + if (data_) { this->~GgufMmap(); new (this) GgufMmap(); } +#if defined(_WIN32) + const int wlen = MultiByteToWideChar(CP_UTF8, 0, path.c_str(), -1, nullptr, 0); + if (wlen <= 0) { + out_error = "GgufMmap: MultiByteToWideChar failed for " + path; + return false; + } + std::wstring wpath; + wpath.resize(wlen - 1); + MultiByteToWideChar(CP_UTF8, 0, path.c_str(), -1, wpath.data(), wlen); + + HANDLE hFile = CreateFileW(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, + nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr); + if (hFile == INVALID_HANDLE_VALUE) { + out_error = "GgufMmap: CreateFileW failed for " + path; + return false; + } + LARGE_INTEGER li; + if (!GetFileSizeEx(hFile, &li)) { + out_error = "GgufMmap: GetFileSizeEx failed for " + path; + CloseHandle(hFile); + return false; + } + size_t file_size = static_cast(li.QuadPart); + + HANDLE hMapping = CreateFileMappingA(hFile, nullptr, PAGE_READONLY, 0, 0, nullptr); + CloseHandle(hFile); + if (!hMapping) { + out_error = "GgufMmap: CreateFileMappingA failed for " + path; + return false; + } + void * addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); + if (!addr) { + out_error = "GgufMmap: MapViewOfFile failed for " + path; + CloseHandle(hMapping); + return false; + } + data_ = addr; + size_ = file_size; + handle_ = reinterpret_cast(hMapping); + return true; +#else + int fd = ::open(path.c_str(), O_RDONLY); + if (fd < 0) { + out_error = std::string("GgufMmap: open failed for ") + path + ": " + strerror(errno); + return false; + } + struct stat st{}; + if (::fstat(fd, &st) < 0) { + out_error = std::string("GgufMmap: fstat failed: ") + strerror(errno); + ::close(fd); + return false; + } + size_t file_size = static_cast(st.st_size); + void * addr = ::mmap(nullptr, file_size, PROT_READ, MAP_PRIVATE, fd, 0); + if (addr == MAP_FAILED) { + out_error = std::string("GgufMmap: mmap failed: ") + strerror(errno); + ::close(fd); + return false; + } + data_ = addr; + size_ = file_size; + fd_ = fd; + return true; +#endif +} + +inline const void * GgufMmap::data() const { return data_; } +inline size_t GgufMmap::size() const { return size_; } +inline bool GgufMmap::is_open() const { return data_ != nullptr; } + +inline GgufMmap::OwnedRegion GgufMmap::release() { + OwnedRegion r{}; + r.data = data_; + r.size = size_; +#if defined(_WIN32) + // Close the mapping handle now. Per MSDN, closing the handle does not + // unmap the view — the view remains valid until UnmapViewOfFile is called. + // OwnedRegion has no handle field; caller unmaps via UnmapViewOfFile(data). + if (handle_) CloseHandle(reinterpret_cast(handle_)); + r.fd = -1; + handle_ = nullptr; +#else + r.fd = fd_; + fd_ = -1; +#endif + data_ = nullptr; + size_ = 0; + return r; +} + +} // namespace dflash::common diff --git a/dflash/src/common/model_backend.h b/dflash/src/common/model_backend.h index 504b68eb..d497df32 100644 --- a/dflash/src/common/model_backend.h +++ b/dflash/src/common/model_backend.h @@ -25,6 +25,10 @@ namespace dflash::common { // Return true to continue generation, false to abort. using TokenCallback = std::function; +// Forward-declare MTP module so the optional mtp() accessor below +// doesn't pull mtp_interface.h into every backend. +namespace mtp { struct IMtpModule; } + // ─── I/O handle passed to backend methods that need protocol output ───── struct DaemonIO { int stream_fd = -1; @@ -174,6 +178,19 @@ struct ModelBackend { // supports_dflash_spec_decode() returns true. Default returns nullptr. virtual class DFlashTarget * dflash_target() { return nullptr; } + // ── MTP speculative decode support ─────────────────────────────── + // Returns true if this backend has an MTP module bolted on (Gemma4's + // external drafter, Qwen3.6's native heads, or any future MTP impl). + // The MTP path is orthogonal to DFlash: both may be supported by the + // same backend (caller picks via request method). + virtual bool supports_mtp() const { return false; } + + // Return the IMtpModule adapter for this backend. Only valid when + // supports_mtp() returns true. Default returns nullptr. + // Forward-declared to avoid a header dependency from model_backend.h + // on mtp_interface.h — backends include both as needed. + virtual mtp::IMtpModule * mtp() { return nullptr; } + // ── Cleanup ────────────────────────────────────────────────────── // Release all resources (weights, cache, snapshots, drafter). // Called by run_daemon() before returning. diff --git a/dflash/src/common/mtp_chain_runner.cpp b/dflash/src/common/mtp_chain_runner.cpp new file mode 100644 index 00000000..1e5ea6b6 --- /dev/null +++ b/dflash/src/common/mtp_chain_runner.cpp @@ -0,0 +1,297 @@ +// mtp_chain_runner.cpp — see mtp_chain_runner.h for contract. + +#include "mtp_chain_runner.h" + +#include "dflash_target.h" + +#include +#include +#include +#include + +namespace dflash::common::mtp { + +MtpChainRunner::MtpChainRunner(IMtpModule & mtp, + DFlashTarget & target, + const SamplerCfg & sampler) + : mtp_(mtp), target_(target), sampler_cfg_(sampler) { +} + +bool MtpChainRunner::propose_drafts_(int32_t current_token, + int base_pos, + int gamma, + const float * prev_hidden, + int prev_hidden_dim, + std::vector & drafts_out, + std::vector & next_hidden_out) { + drafts_out.clear(); + drafts_out.reserve(gamma); + + if (mtp_.flavor() == MtpFlavor::NativeHeads) { + auto & native = static_cast(mtp_); + std::vector outs; + // Phase A: ask the native module for `gamma` autoregressive drafts. + // For depth==1 (today's production for any pre-Phase-A native module) + // this is byte-identical to the old step_batch path via the default + // impl in INativeMtp::step_chain. For depth>1 the module chains its + // own forward and returns up to `gamma` drafts. + if (!native.step_chain(current_token, base_pos, gamma, outs)) return false; + const int got = (int)outs.size(); + const int take = std::min(gamma, got); + for (int i = 0; i < take; i++) drafts_out.push_back(outs[i].draft_token); + // NativeHeads does not produce h_prev — leave next_hidden_out empty. + next_hidden_out.clear(); + return true; + } + + // ExternalDrafter: γ serial step() calls, threading h_prev. + auto & ext = static_cast(mtp_); + const int H = mtp_.hidden_size(); + + std::vector running_hidden; + if (prev_hidden && prev_hidden_dim == H) { + running_hidden.assign(prev_hidden, prev_hidden + H); + } else { + // Caller did not supply h_prev — module must be in a state where + // step() can handle a null prev_hidden (e.g. first iter of a chain). + } + + int32_t cur = current_token; + for (int g = 0; g < gamma; g++) { + StepInput in; + in.current_token = cur; + in.base_pos = base_pos; // committed base; module uses gamma_index for offset + in.gamma_index = g; + in.prev_hidden = running_hidden.empty() ? nullptr : running_hidden.data(); + in.prev_hidden_dim = (int)running_hidden.size(); + + StepOutput out; + if (!ext.step(in, out)) return false; + + drafts_out.push_back(out.draft_token); + cur = out.draft_token; + + // Thread next_hidden into the following iter. If the module did + // not produce a hidden (e.g. final step on a small chain), keep + // running_hidden as-is so the next iter's caller can still pass + // the most recent value. + if (!out.next_hidden.empty()) { + running_hidden = std::move(out.next_hidden); + } + } + + next_hidden_out = std::move(running_hidden); + return true; +} + +GenerateResult MtpChainRunner::run(const GenerateRequest & req, + const DaemonIO & io, + int32_t last_prefill_token, + int committed_pos, + int gamma) { + GenerateResult result; + const int n_gen = req.n_gen; + if (n_gen <= 0) { result.ok = true; return result; } + + // Clamp γ to the module's stated ceiling. + const int gamma_max = std::max(1, mtp_.max_gamma()); + if (gamma > gamma_max) { + std::fprintf(stderr, + "[mtp_chain_runner] γ=%d > module max_gamma=%d; clamping.\n", + gamma, gamma_max); + gamma = gamma_max; + } + if (gamma < 1) gamma = 1; + + auto t0 = std::chrono::steady_clock::now(); + result.tokens.reserve(n_gen); + + int32_t cur_tok = last_prefill_token; + int base_pos = committed_pos; + + // Enable per-position DeltaNet intermediate capture for the duration of + // this run, so verify_batch records the per-slot SSM state and conv input + // window we need for restore_kv_at_chain() on partial-accept iters. + // Safe because chain verify candidates are bounded by g_actual+1 <= + // max_gamma+1 <= max_verify_tokens. RAII: turned off on every exit path. + struct ChainCaptureGuard { + DFlashTarget & t; + ~ChainCaptureGuard() { t.enable_chain_capture(false); } + }; + target_.enable_chain_capture(true); + ChainCaptureGuard guard{target_}; + + // ExternalDrafter optionally provides h_prev across iters. NativeHeads + // does not — running_hidden stays empty for that flavor. + std::vector running_hidden; + + bool hit_eos = false; + + while ((int)result.tokens.size() < n_gen && !hit_eos) { + const int remaining = n_gen - (int)result.tokens.size(); + const int g_iter = std::min(gamma, remaining); + + // ── Propose γ drafts ─────────────────────────────────────────── + std::vector drafts; + std::vector next_hidden; + if (!propose_drafts_(cur_tok, base_pos, g_iter, + running_hidden.empty() ? nullptr : running_hidden.data(), + (int)running_hidden.size(), + drafts, next_hidden)) { + result.ok = false; + result.error = "mtp.propose"; + return result; + } + const int g_actual = (int)drafts.size(); + stats_.total_proposed += g_actual; + + // ── Verify on target ─────────────────────────────────────────── + // Candidate sequence: [cur_tok, drafts[0..g_actual-1]] + // After verify, all_argmax[i] = target's argmax AFTER seeing + // candidate[i] at base_pos+i. + std::vector candidate; + candidate.reserve(1 + g_actual); + candidate.push_back(cur_tok); + for (auto d : drafts) candidate.push_back(d); + + if (!target_.snapshot_kv()) { + result.ok = false; + result.error = "snapshot_kv"; + return result; + } + + // Caller-owned topology for restore_kv_at_chain (bug #2). + target_.capture_topology_for_chain((int)candidate.size(), base_pos); + + int last_argmax = -1; + std::vector all_argmax; + if (!target_.verify_batch(candidate, base_pos, last_argmax, &all_argmax)) { + target_.restore_kv(); + result.ok = false; + result.error = "verify_batch"; + return result; + } + if ((int)all_argmax.size() < (int)candidate.size()) { + target_.restore_kv(); + result.ok = false; + result.error = "verify_batch_short"; + return result; + } + + // ── Accept longest matching prefix ───────────────────────────── + int accept_n = 0; + for (int i = 0; i < g_actual; i++) { + if (drafts[i] == all_argmax[i]) accept_n++; + else break; + } + + // Total tokens this iter = accept_n + 1 (the bonus from target's + // argmax at the divergence point or after the last accepted draft). + const int total_this_iter = accept_n + 1; + + // ── KV reconciliation ────────────────────────────────────────── + // Three paths converge on the SAME post-iter invariant: + // base_pos advances by total_this_iter = accept_n + 1 + // cur_tok = bonus (= all_argmax[accept_n]) + // and the bonus's KV is written by the NEXT iter's verify_batch. + // + // 1. accept-all (accept_n == g_actual): verify_batch wrote g+1 slots; + // we treat the last (bonus) slot as uncommitted and let next iter + // overwrite it. + // 2. fast rollback (restore_kv_at_chain succeeds): rolls cache.cur_pos + // to base_pos + accept_n + 1, leaving the bonus slot unwritten. + // 3. recommit (fast path declined): snapshot+restore, then recommit + // only [cur, accepted...] (accept_n+1 slots, NO bonus). Bonus is + // threaded via cur_tok like the other paths. Advancing by + // accept_n+2 here would skip a position every recommit iter and + // diverge from AR — see test_recommit_byte_identical_to_ar. + if (accept_n < g_actual) { + if (!target_.restore_kv_at_chain(accept_n)) { + // Slow path: snapshot rollback + commit ONLY [cur, accepted...] + // (accept_n+1 slots). Bonus stays uncommitted, threaded via + // cur_tok like the fast path. + if (!target_.restore_kv()) { + result.ok = false; + result.error = "restore_kv"; + return result; + } + std::vector commit_seq; + commit_seq.reserve((size_t)accept_n + 1); + commit_seq.push_back(cur_tok); + for (int i = 0; i < accept_n; i++) commit_seq.push_back(drafts[i]); + int discard = -1; + if (!target_.verify_batch(commit_seq, base_pos, discard, nullptr)) { + result.ok = false; + result.error = "recommit"; + return result; + } + } + } + + // ── Emit accepted prefix + bonus, capped at n_gen ────────────── + // emit_cap is the absolute ceiling on tokens that may be written + // this iter; the runner advances KV by total_this_iter regardless + // (we already verified/recommitted that many positions) so KV state + // stays consistent even when emission is truncated. + const int emit_cap = std::min(total_this_iter, + n_gen - (int)result.tokens.size()); + int emitted = 0; + for (int i = 0; i < accept_n && emitted < emit_cap; i++) { + result.tokens.push_back(drafts[i]); + if (req.stream) io.emit(drafts[i]); + emitted++; + if (target_.is_eos(drafts[i])) { hit_eos = true; break; } + } + if (!hit_eos && emitted < emit_cap) { + const int32_t bonus = all_argmax[accept_n]; + result.tokens.push_back(bonus); + if (req.stream) io.emit(bonus); + emitted++; + if (target_.is_eos(bonus)) hit_eos = true; + cur_tok = bonus; + } else { + cur_tok = result.tokens.empty() ? cur_tok : result.tokens.back(); + } + + // All paths share: base_pos += total_this_iter and cur_tok = bonus. + base_pos += total_this_iter; + + stats_.total_iters += 1; + stats_.total_accepted += accept_n; + stats_.total_emitted += emitted; + + // ── Thread h_prev for next iteration ────────────────────────── + // ExternalDrafter: on partial accept (accept_n < g_actual) the + // drafter-internal next_hidden was produced after the REJECTED + // suffix, not at the committed boundary. Use the target-captured + // hidden at row accept_n instead (cubic review finding P1). + // On full accept next_hidden is already correct (last row == last + // accepted position), so the same set_capture_row path is safe. + // NativeHeads: next_hidden is always empty; this block is skipped. + if (mtp_.flavor() == MtpFlavor::ExternalDrafter && !next_hidden.empty()) { + auto & ext = static_cast(mtp_); + const int H = mtp_.hidden_size(); + ext.set_capture_row(accept_n); + std::vector boundary_hidden((size_t)H); + if (ext.consume_captured_hidden(boundary_hidden.data(), H)) { + running_hidden = std::move(boundary_hidden); + } else { + // consume failed (e.g. stub); fall back to drafter output + running_hidden = std::move(next_hidden); + } + } else { + running_hidden = std::move(next_hidden); + } + } + + if (hit_eos) stats_.eos_hits++; + + if (req.stream) io.emit(-1); + + auto t1 = std::chrono::steady_clock::now(); + result.decode_s = std::chrono::duration(t1 - t0).count(); + result.ok = true; + return result; +} + +} // namespace dflash::common::mtp diff --git a/dflash/src/common/mtp_chain_runner.h b/dflash/src/common/mtp_chain_runner.h new file mode 100644 index 00000000..b3a3305e --- /dev/null +++ b/dflash/src/common/mtp_chain_runner.h @@ -0,0 +1,87 @@ +// mtp_chain_runner.h — Generic γ-loop for MTP speculative decoding. +// +// Drives an IMtpModule (ExternalDrafter or NativeHeads) through γ +// speculative steps per iteration, verifies the resulting candidate +// sequence on the target via DFlashTarget::verify_batch, accepts the +// longest matching prefix + 1 bonus token, and rolls back the target's +// KV cache on any reject. Identical verify path for both flavors; +// only the propose step dispatches on flavor(). +// +// Follows the same γ-loop pattern as common/dflash_spec_decode.h. + +#pragma once + +#include "model_backend.h" // GenerateRequest, GenerateResult, DaemonIO +#include "mtp_interface.h" +#include "sampler.h" + +#include + +namespace dflash::common { + +struct DFlashTarget; // forward — see common/dflash_target.h + +namespace mtp { + +// Per-iteration telemetry. Aggregated into MtpChainRunner::stats() so +// callers (test, daemon) can report acceptance / chain depth. +struct MtpChainStats { + int total_iters = 0; // chain iterations executed + int total_proposed = 0; // draft tokens proposed (Σ γ across iters) + int total_accepted = 0; // draft tokens accepted (Σ accept_n) + int total_emitted = 0; // tokens written to out_tokens (accepted + bonus) + int eos_hits = 0; +}; + +class MtpChainRunner { +public: + MtpChainRunner(IMtpModule & mtp, + DFlashTarget & target, + const SamplerCfg & sampler); + + // Run the MTP γ-loop over the prompt in `req`, writing decoded + // tokens into the result. Caller is responsible for prefill — this + // runner assumes target.verify_batch and DFlashTarget snapshot_kv + // /restore_kv are in a state where the last prefill token has been + // committed and `committed_pos` (passed as `req.snap_pos` when set, + // else the prefill length) points just past it. + // + // `gamma` is the chain length. `gamma > mtp.max_gamma()` is clamped + // and a stderr warning is emitted once per run. + // + // The runner does not own prefill — that's the backend's job. It + // does own: propose, verify, accept/rollback, sample-on-tie, + // emit (stream), EOS detection. + GenerateResult run(const GenerateRequest & req, + const DaemonIO & io, + int32_t last_prefill_token, + int committed_pos, + int gamma); + + const MtpChainStats & stats() const { return stats_; } + +private: + IMtpModule & mtp_; + DFlashTarget & target_; + SamplerCfg sampler_cfg_; + MtpChainStats stats_; + + // Propose γ draft tokens from the current position. Dispatches on + // mtp_.flavor(); ExternalDrafter threads prev_hidden through γ + // serial step() calls, NativeHeads issues one step_batch(). + // Returns false on module failure (callers abort the run). + // + // `prev_hidden` is the host-side h_prev captured from the previous + // commit's post-norm hidden (ExternalDrafter only; ignored for + // NativeHeads). `prev_hidden_dim` must equal mtp_.hidden_size(). + bool propose_drafts_(int32_t current_token, + int base_pos, + int gamma, + const float * prev_hidden, + int prev_hidden_dim, + std::vector & drafts_out, + std::vector & next_hidden_out); +}; + +} // namespace mtp +} // namespace dflash::common diff --git a/dflash/src/common/mtp_interface.h b/dflash/src/common/mtp_interface.h new file mode 100644 index 00000000..f0b04517 --- /dev/null +++ b/dflash/src/common/mtp_interface.h @@ -0,0 +1,215 @@ +// mtp_interface.h — Generic MTP (Multi-Token Prediction) module interface. +// +// Hosts multiple MTP designs under one outer abstraction: +// +// - ExternalDrafter (Gemma4): separate drafter weights with explicit +// h_prev chain and cross-attention into the target's KV donor layers. +// - NativeHeads (Qwen3.6): MTP heads built into the backbone with +// embedding + LM-head weight sharing; no explicit chain. +// +// A common base IMtpModule expresses what every MTP truly has; per-flavor +// mixins (IExternalDrafterMtp, INativeMtp) carry the flavor-specific +// surface. The γ-loop + verify + accept/rollback live once in +// MtpChainRunner (see mtp_chain_runner.h) and dispatch on `flavor()`. +// +// This file is peer to dflash_target.h and follows the same pattern: +// the target's existing DFlashTarget adapter provides everything the +// chain runner needs (verify_batch, snapshot_kv, restore_kv, +// embed_tokens, is_eos, hidden_size); MTP modules do not duplicate it. + +#pragma once + +#include +#include + +namespace dflash::common { + +struct DFlashTarget; // forward — see common/dflash_target.h + +namespace mtp { + +// ── Flavor tag ────────────────────────────────────────────────────────── +// MtpChainRunner dispatches on this; concrete classes set it via the +// matching mixin (IExternalDrafterMtp / INativeMtp). +enum class MtpFlavor { + ExternalDrafter, // Gemma4-style: separate drafter, h_prev chain + NativeHeads, // Qwen3.6-style: MTP heads in the backbone +}; + +// ── Per-step value types ──────────────────────────────────────────────── +// +// StepInput / StepOutput describe one γ-step of speculation. They are +// flavor-agnostic at the type level; flavor-specific fields are nullable +// (prev_hidden is consumed only by ExternalDrafter; NativeHeads ignores +// it and uses step_batch() to emit all γ tokens at once). + +struct StepInput { + int32_t current_token = -1; // last accepted token id + int base_pos = 0; // committed target position + int gamma_index = 0; // 0..gamma-1 within the chain + const float * prev_hidden = nullptr; // ExternalDrafter only; null otherwise + int prev_hidden_dim = 0; // length of prev_hidden when non-null +}; + +struct StepOutput { + int32_t draft_token = -1; + float draft_logit = 0.0f; + std::vector next_hidden; // ExternalDrafter writes h_post; empty for NativeHeads + + // Optional top-K logprobs surface for tree-structured drafting (DDTree). + // Empty when the module is configured for argmax-only drafting (the + // default). When populated by NativeHeads with K>1, both vectors are + // length K, sorted DESCENDING by logprob (rank 0 == argmax). For + // multi-head emission the runner builds a [L * K] layout by stacking + // the per-head vectors in order; each StepOutput holds the K entries + // for its own depth. + std::vector topk_logprobs; + std::vector topk_ids; +}; + +// ── Common base ───────────────────────────────────────────────────────── +// +// Methods every MTP implementation truly has. LSP-safe: callers that +// only need flavor-agnostic lifecycle (attach / reset / shutdown) work +// against this base; flavor-specific entry points live on the mixins. + +struct IMtpModule { + virtual ~IMtpModule() = default; + + // Identifies which mixin (and therefore which entry point) this + // module implements. Set via the matching mixin's `final` override; + // do not override directly in concrete classes. + virtual MtpFlavor flavor() const = 0; + + // Architectural ceiling on chain length. Qwen3.6 typically returns 2; + // Gemma4 returns the drafter's trained chain depth. + virtual int max_gamma() const = 0; + + // Requested operating γ for this module. Set once via set_effective_gamma + // before the first decode. Orchestrator + chain runner read this — no + // parallel storage anywhere else. Bug class blocked by construction: + // gamma cannot disagree between caller and module. + virtual int effective_gamma() const = 0; + virtual void set_effective_gamma(int gamma) = 0; + + // Backbone hidden size the module operates against. Must match the + // target's DFlashTarget::hidden_size() exactly; chain runner asserts. + virtual int hidden_size() const = 0; + + // Bind the module to its target (KV cache, embedding, EOS predicate, + // LM-head projection). Called once before the first step; returns + // false if shapes / arches are incompatible. + virtual bool attach(DFlashTarget * target) = 0; + + // Clear any per-chain state (e.g. h_prev ring head, partial-accept + // bookkeeping). Called by the runner between user requests. + virtual void reset_chain() = 0; + + // Release all device + host resources. Called at backend shutdown. + virtual void shutdown() = 0; + + // Seed h_prev for the first chain step (last post-norm hidden from prefill). + // Default no-op; both ExternalDrafter and NativeHeads override. + virtual void set_initial_hidden(const float * /*h_prev*/, int /*dim*/) {} +}; + +// ── ExternalDrafter mixin ─────────────────────────────────────────────── +// +// For MTP designs whose drafter is a separate model that reads the +// target's intermediate KV state and propagates an h_prev hidden through +// the γ chain (Gemma4 today; future external drafters plug in here). + +struct IExternalDrafterMtp : IMtpModule { + MtpFlavor flavor() const final { return MtpFlavor::ExternalDrafter; } + + // Single drafter step. The runner threads `prev_hidden` (the + // captured h_prev from the previous step or from the target's + // post-norm hidden after the last commit) into `StepInput`. + // On return, `StepOutput::next_hidden` carries the drafter's h_post + // for the next iteration. + virtual bool step(const StepInput & in, StepOutput & out) = 0; + + // Which target layers the drafter cross-attends to. Resolved at + // load time (e.g. Gemma4's resolve_mtp_donor_layers). The runner + // hands these to the target's DFlashTarget::verify_batch path so + // the target captures activations at exactly these layers. + virtual const std::vector & donor_layers() const = 0; + + // Configure the target-side h_prev capture buffer. + // batch_mode=false : single-row capture (γ=1 path). + // batch_mode=true : write all n_tokens rows during verify so + // partial-accept γ>1 can pick the right row + // host-side without a re-capture forward. + // `gamma_max` sizes the batch buffer. + virtual bool enable_target_hidden_capture(bool batch_mode, int gamma_max) = 0; + + // For γ>1 partial-accept: the row of the post-norm hidden tensor + // to read into prev_hidden on the next step. Default sentinel -1 + // means "last row" (matches γ=1 contract). + virtual void set_capture_row(int row) = 0; + + // Host-readable copy of the captured h_prev for the next step. + // `out` must have space for `hidden_size()` floats; `dim` is the + // caller's expected dim and is asserted to match. + virtual bool consume_captured_hidden(float * out, int dim) = 0; +}; + +// ── NativeHeads mixin ─────────────────────────────────────────────────── +// +// For MTP designs where the heads live inside the target's backbone and +// emit multiple draft tokens in one forward (Qwen3.6 today; DeepSeek-V3 +// style would fit here too). + +struct INativeMtp : IMtpModule { + MtpFlavor flavor() const final { return MtpFlavor::NativeHeads; } + + // Number of draft tokens emitted per call to step_batch(). + // Bounded by max_gamma(); typically 1–2 for Qwen3.6. + virtual int num_heads() const = 0; + + // Emit up to `num_heads()` draft tokens in a single backbone-aware + // forward. The runner calls this once per chain (no h_prev threading); + // `out` is sized by the implementation to num_heads(). + virtual bool step_batch(int32_t current_token, + int base_pos, + std::vector & out) = 0; + + // Configure per-head top-K logprob emission. Default K=1 means argmax + // only (StepOutput.topk_* stays empty — pre-existing behavior). With + // K>1, step_batch additionally fills StepOutput.topk_logprobs and + // topk_ids (length K, sorted DESCENDING) on every emitted head, which + // the DDTree builder consumes. Concrete impls override; the default + // no-op keeps fake/stub subclasses compatible with the existing ABI. + virtual void set_draft_topk(int /*k*/) {} + + // Multi-step autoregressive chain draft. Concrete implementations chain + // their own forward `chain_depth` times, feeding the head's own + // post-shared_head_norm hidden as h_prev for the next iteration. `out` + // is sized to the number of drafts actually emitted (≤ chain_depth); + // each element is populated like step_batch (draft_token / draft_logit + // and optionally topk_*). + // + // Default implementation: forward to step_batch and return ALL emitted + // drafts unchanged. This preserves the pre-Phase-A semantics for + // multi-head native modules (where step_batch emits `num_heads` drafts + // per call) — the runner sees the same draft count it did before, and + // `chain_depth` is effectively ignored for any module that hasn't + // overridden step_chain. Overriders (Qwen35MtpModule today) treat + // chain_depth as authoritative. + virtual bool step_chain(int32_t current_token, + int base_pos, + int /*chain_depth*/, + std::vector & out) { + return step_batch(current_token, base_pos, out); + } + + // Pre-warm head K/V over all prefill positions. `hiddens` is the backbone's + // per-position post-norm sequence laid out [tok0_hidden, ..., tokN_hidden]. + virtual bool warm_head_kv(const int32_t * /*prompt*/, int /*n_prompt*/, + int32_t /*prefill_next*/, const float * /*hiddens*/) { + return true; + } +}; + +} // namespace mtp +} // namespace dflash::common diff --git a/dflash/src/common/mtp_orchestrator.cpp b/dflash/src/common/mtp_orchestrator.cpp new file mode 100644 index 00000000..86aa5e56 --- /dev/null +++ b/dflash/src/common/mtp_orchestrator.cpp @@ -0,0 +1,177 @@ +#include "common/mtp_orchestrator.h" +#include "common/dflash_target.h" +#include "common/mtp_chain_runner.h" + +#include +#include +#include +#include + +namespace dflash::common { +namespace mtp { + +namespace { +constexpr int kDefaultPrefillUbatch = 512; + +int env_int(const char * name, int defv) { + if (const char * s = std::getenv(name)) { + const int v = std::atoi(s); + if (v > 0) return v; + } + return defv; +} +} + +GenerateResult warm_and_decode(ModelBackend * backend, + const GenerateRequest & req, + const DaemonIO & io_in) { + // Mirror laguna_backend.cpp:151 / gemma4_backend.cpp:172: wrap io + // with the request's token callback so MTP requests get streaming + // disconnect cancellation and per-token notifications. + const DaemonIO io = io_in.with_token_callback(req.on_token); + + GenerateResult result; + if (!backend) { + result.error = "warm_and_decode: backend pointer is null"; + return result; + } + if (!backend->supports_mtp()) { + result.error = "warm_and_decode: backend does not support MTP"; + return result; + } + if (req.prompt.empty()) { + result.error = "warm_and_decode: prompt is empty"; + return result; + } + + dflash::common::mtp::IMtpModule * module = backend->mtp(); + DFlashTarget * target = backend->dflash_target(); + if (!module || !target) { + result.error = "warm_and_decode: backend missing mtp() or dflash_target()"; + return result; + } + + const int hidden = target->hidden_size(); + const int prompt_len = (int)req.prompt.size(); + const int prefill_ubatch = env_int("DFLASH27B_PREFILL_UBATCH", kDefaultPrefillUbatch); + + // Capture state is owned by the target+MTP attachment, not the orchestrator. + // MTP's attach() already enabled+pinned FULL_SEQ; calling here would be a + // no-op and an architectural smell (orchestrator reaching into target state). + // No-op for non-MTP targets; Qwen35DFlashTarget overrides to pin capture. + target->enable_hidden_seq_capture(true); // idempotent for MTP-bound target + + std::vector all_prefill_hidden((size_t)prompt_len * hidden); + int32_t last_tok = -1; + + auto t_prefill0 = std::chrono::steady_clock::now(); + for (int start = 0; start < prompt_len;) { + const int n = std::min(prefill_ubatch, prompt_len - start); + std::vector chunk(req.prompt.begin() + start, + req.prompt.begin() + start + n); + if (!target->verify_batch(chunk, start, last_tok, nullptr)) { + target->enable_hidden_seq_capture(false); + result.error = "warm_and_decode: verify_batch failed during prefill"; + io.emit(-1); + return result; + } + int n_chunk = 0; + const float * h_seq = target->last_hidden_seq(&n_chunk); + // Invariant: capture is enabled+pinned by MTP attach, so verify_batch + // must return the full chunk. If it doesn't, fail loud rather than + // silently mangle all_prefill_hidden — clearing it (the pre-fix + // behavior) made the next chunk's memcpy write past freed memory. + if (!h_seq || n_chunk != n) { + result.error = "warm_and_decode: hidden seq capture invariant violated"; + io.emit(-1); + return result; + } + std::memcpy(all_prefill_hidden.data() + (size_t)start * hidden, + h_seq, sizeof(float) * (size_t)n * hidden); + start += n; + } + result.prefill_s = std::chrono::duration( + std::chrono::steady_clock::now() - t_prefill0).count(); + + // No scope toggle here: MTP-pinned target stays FULL_SEQ for the chain's + // whole lifetime (partial-accept iters need the COMMITTED row, not the + // last-candidate row, so LAST_ROW_ONLY would silently return null). + + if (last_tok < 0) { + result.error = "warm_and_decode: prefill produced invalid argmax"; + io.emit(-1); + return result; + } + + module->reset_chain(); + if (target->last_hidden() != nullptr) { + module->set_initial_hidden(target->last_hidden(), hidden); + } + if (module->flavor() == dflash::common::mtp::MtpFlavor::NativeHeads + && !all_prefill_hidden.empty()) { + // flavor() guarantees the concrete type; static_cast is safe. + auto * native = static_cast(module); + if (native && !native->warm_head_kv(req.prompt.data(), prompt_len, + last_tok, all_prefill_hidden.data())) { + result.error = "warm_and_decode: warm_head_kv failed"; + io.emit(-1); + return result; + } + } + + // Emit prefill token, then drive chain runner. + result.tokens.push_back(last_tok); + io.emit(last_tok); + if (target->is_eos(last_tok) || req.n_gen <= 1) { + io.emit(-1); + result.ok = true; + return result; + } + + SamplerCfg sampler = req.sampler; + GenerateRequest inner; + inner.n_gen = req.n_gen - 1; + inner.stream = true; + inner.do_sample = false; + inner.sampler = sampler; + + // Single source of truth: backend must have called set_effective_gamma + // at attach time. effective_gamma() == 0 means the backend forgot — fail + // loud rather than silently default to max_gamma (the bug class that + // tanked accept_rate from 0.41 to 0.04 in the earlier orchestrator). + const int gamma = module->effective_gamma(); + if (gamma <= 0) { + result.error = "warm_and_decode: module->effective_gamma() == 0 — backend must call set_effective_gamma() during attach"; + io.emit(-1); + return result; + } + + auto t_decode0 = std::chrono::steady_clock::now(); + dflash::common::mtp::MtpChainRunner runner(*module, *target, sampler); + GenerateResult inner_res = runner.run(inner, io, last_tok, prompt_len, gamma); + result.decode_s = std::chrono::duration( + std::chrono::steady_clock::now() - t_decode0).count(); + + if (!inner_res.ok) { + result.error = "warm_and_decode: chain runner failed: " + inner_res.error; + io.emit(-1); + return result; + } + + for (int32_t t : inner_res.tokens) result.tokens.push_back(t); + + const auto & st = runner.stats(); + if (st.total_iters > 0) { + std::fprintf(stderr, + "[mtp_decode] iters=%d proposed=%d accepted=%d emitted=%d accept_rate=%.2f\n", + st.total_iters, st.total_proposed, st.total_accepted, st.total_emitted, + st.total_proposed > 0 + ? (double)st.total_accepted / (double)st.total_proposed : 0.0); + } + + result.ok = true; + return result; +} + +} // namespace mtp +} // namespace dflash::common diff --git a/dflash/src/common/mtp_orchestrator.h b/dflash/src/common/mtp_orchestrator.h new file mode 100644 index 00000000..6ef74ef4 --- /dev/null +++ b/dflash/src/common/mtp_orchestrator.h @@ -0,0 +1,44 @@ +// mtp_orchestrator.h — Generic MTP warm + decode driver. +// +// All compute (prefill, attention, MTP head forward, sampling) goes through +// DFlashTarget::verify_batch and IMtpModule::step_batch — both ggml graphs +// on the backend's device. Orchestrator owns only control flow and a single +// host-side hidden-sequence buffer for warm_head_kv. + +#pragma once + +#include "model_backend.h" +#include "mtp_interface.h" + +namespace dflash::common { + +class DFlashTarget; + +namespace mtp { + +// Drive the full MTP warm + chain decode for one request. +// +// Preconditions: +// - backend != nullptr +// - backend->supports_mtp() returns true (else returns error early) +// - backend->mtp() and backend->dflash_target() return non-null +// - req.prompt is non-empty +// +// Behavior: +// 1. Chunked prefill via DFlashTarget::verify_batch, capturing the +// backbone's per-position pre/post norm hidden states. +// 2. Seed the MTP module with the last hidden + warm head KV across +// all prompt positions. +// 3. Run MtpChainRunner for n_gen decode tokens at the given gamma. +// 4. Stream tokens through io.emit() and append them to +// result.tokens. Emit the terminal -1 sentinel. +// +// Returns a GenerateResult populated with tokens / prefill_s / decode_s. +// On any failure, .ok = false and .error describes the cause (matches +// the daemon log's "err " line). +GenerateResult warm_and_decode(ModelBackend * backend, + const GenerateRequest & req, + const DaemonIO & io); + +} // namespace mtp +} // namespace dflash::common diff --git a/dflash/src/common/step_graph.h b/dflash/src/common/step_graph.h index 5e18f3f6..cdb160cc 100644 --- a/dflash/src/common/step_graph.h +++ b/dflash/src/common/step_graph.h @@ -39,6 +39,20 @@ struct StepGraph { ggml_tensor * hidden_states = nullptr; // draft hidden-only output ggml_tensor * argmax_tokens = nullptr; // [n_tokens] i32, GPU-side argmax of logits ggml_tensor * topk_indices = nullptr; // [K, n_tokens] i32, GPU-side top-K indices + // Post-norm hidden for last token [n_embd] f32. Used by MTP module to + // seed h_prev_0. Populated by build_target_step; null otherwise. + ggml_tensor * last_norm_hidden = nullptr; + // Full post-norm hidden sequence [n_embd, n_tokens] f32. Used by + // warm_head_kv() to read per-position hiddens during prefill. + ggml_tensor * all_norm_hidden = nullptr; + // Pre-final-output-norm hidden — last token [n_embd] f32 and full + // sequence [n_embd, n_tokens] f32. Used by the Qwen3.6 MTP module to + // seed the NextN head's h_prev WITHOUT double-normalising against the + // head's own hnorm (mirror of llama.cpp PR #22673 `t_h_pre_norm`). + // Populated alongside last_/all_norm_hidden when the caller asks for + // capture_all_norm_hidden. + ggml_tensor * last_h_pre_norm = nullptr; + ggml_tensor * all_h_pre_norm = nullptr; // Per-delta-net-layer captures (verify only). std::vector delta_captures; @@ -57,6 +71,10 @@ inline void step_graph_free(StepGraph & sg) { sg.hidden_states = nullptr; sg.argmax_tokens = nullptr; sg.topk_indices = nullptr; + sg.last_norm_hidden = nullptr; + sg.all_norm_hidden = nullptr; + sg.last_h_pre_norm = nullptr; + sg.all_h_pre_norm = nullptr; sg.delta_captures.clear(); } diff --git a/dflash/src/internal.h b/dflash/src/internal.h index 6f5666df..d2fe48b0 100644 --- a/dflash/src/internal.h +++ b/dflash/src/internal.h @@ -498,11 +498,33 @@ struct QwenGraphInputs { bool capture_delta_intermediate = false; // if true, populate out_delta_captures int fa_window = 0; // sliding window for FA layers: 0 = full attention bool last_token_logits_only = false; // if true, only compute logits for last token (prefill optimization) + bool capture_all_norm_hidden = false; // if true, expose full [n_embd, n_tokens] post-norm hidden as a graph output (MTP warmup needs this; non-MTP callers should leave false to avoid pinning ~7.5MB at ubatch=384) ggml_tensor * parent_ids = nullptr; // [n_tokens] i32; tree mode when non-null }; struct QwenGraphOutputs { ggml_tensor * logits; // [vocab, n_tokens] f32 + // Post-norm hidden for the last committed token, shape [n_embd], f32. + // This is the output of out_norm before the LM-head projection. Set as + // ggml_set_output so it remains valid after graph_compute. Used by the + // Qwen3.6 MTP module to seed h_prev_0 for the first NextN head. + ggml_tensor * last_norm_hidden = nullptr; // [n_embd] f32 + // Full post-norm hidden sequence, shape [n_embd, n_tokens] f32. + // Same tensor as last_norm_hidden's parent; set as ggml_set_output so + // the MTP warm_head_kv() can read per-position hiddens during prefill. + ggml_tensor * all_norm_hidden = nullptr; // [n_embd, n_tokens] f32 + // Pre-final-output-norm hidden state. Mirrors llama.cpp PR #22673's + // `t_h_pre_norm` (exposed in src/models/qwen35.cpp before the final + // build_norm). The Qwen3.6 MTP module seeds its NextN head with this + // tensor for the OUTER spec-chain step (h_prev_0): feeding the + // post-output-norm hidden double-normalises against the head's own + // hnorm and compounds the per-depth rejection rate. Last column only; + // shape [n_embd] f32. Populated when capture_all_norm_hidden=true + // (the MTP enables hidden-seq capture and wants both). + ggml_tensor * last_h_pre_norm = nullptr; // [n_embd] f32 + // Full pre-output-norm hidden sequence [n_embd, n_tokens] f32. Same + // capture flag as all_norm_hidden. + ggml_tensor * all_h_pre_norm = nullptr; // [n_embd, n_tokens] f32 // One entry per delta-net layer (48 for qwen35-27b). Only populated when // QwenGraphInputs::capture_delta_intermediate is true. Tensors are graph // views marked as ggml_set_output() so their data persists after diff --git a/dflash/src/qwen35/gguf_target_loader.cpp b/dflash/src/qwen35/gguf_target_loader.cpp index 62e209eb..30779194 100644 --- a/dflash/src/qwen35/gguf_target_loader.cpp +++ b/dflash/src/qwen35/gguf_target_loader.cpp @@ -278,7 +278,13 @@ bool load_target_gguf_partial(const std::string & path, std::string err; const uint32_t n_embd = get_u32_or(gctx, "qwen35.embedding_length", 0); const uint32_t n_ff = get_u32_or(gctx, "qwen35.feed_forward_length", 0); - const uint32_t n_layer= get_u32_or(gctx, "qwen35.block_count", 0); + const uint32_t n_block_raw = get_u32_or(gctx, "qwen35.block_count", 0); + // Qwen3.6 MTP GGUFs append `nextn_predict_layers` extra blocks holding the + // NextN heads (e.g. block_count=65 = 63 backbone + 2 MTP heads). The qwen35 + // MTP loader picks those up separately; here we treat n_layer as backbone + // layers only so the divisibility check + tensor binding ignore the heads. + const uint32_t n_nextn = get_u32_or(gctx, "qwen35.nextn_predict_layers", 0); + const uint32_t n_layer = n_block_raw > n_nextn ? n_block_raw - n_nextn : n_block_raw; const uint32_t n_head = get_u32_or(gctx, "qwen35.attention.head_count",0); const uint32_t n_headkv=get_u32_or(gctx, "qwen35.attention.head_count_kv",0); const uint32_t kl = get_u32_or(gctx, "qwen35.attention.key_length", 0); @@ -552,6 +558,18 @@ bool load_target_gguf_partial(const std::string & path, size_t total = 0; size_t tok_embd_off = 0, tok_embd_sz = 0; ggml_type tok_embd_type = GGML_TYPE_COUNT; + // F32 norm/SSM weights are the easiest way to spot a corrupt GGUF: any + // NaN in them poisons every later layer's logits to NaN, and the only + // user-visible symptom is the cryptic "prefill produced invalid token" + // emitted N seconds later by test_dflash's argmax check. Unsloth's + // Qwen3.6-27B-MTP-Q4_K_M.gguf (May 2026) ships with ~349 corrupt F32 + // tensors (NaN in attn_norm / post_attention_norm / ssm_alpha / ssm_beta + // / ssm_conv1d / *_q_norm / *_k_norm), so we fail-fast here with the + // exact tensor name + count instead of letting the GPU graph silently + // produce nan-everywhere logits. CPU-side mmap scan, no extra GPU work. + int corrupt_f32_count = 0; + std::string first_corrupt_name; + size_t first_corrupt_nan = 0; for (int64_t tid = 0; tid < n_tensors; tid++) { const char * tname = gguf_get_tensor_name(gctx, tid); ggml_tensor * t = ggml_get_tensor(meta_ctx, tname); @@ -573,9 +591,39 @@ bool load_target_gguf_partial(const std::string & path, if (!should_load_target_tensor(tname, plan.layer_begin, plan.layer_end, plan.load_output)) { continue; } + if (gguf_get_tensor_type(gctx, tid) == GGML_TYPE_F32) { + const float * fp = (const float *)((const uint8_t *)mm.addr + off); + const size_t n_elem = sz / sizeof(float); + size_t n_nan = 0; + for (size_t i = 0; i < n_elem; i++) { + // Bit-level NaN test: any float whose exponent is all-1 and + // mantissa is non-zero. Faster + portable vs std::isnan. + uint32_t u; + std::memcpy(&u, fp + i, sizeof(u)); + if ((u & 0x7F800000u) == 0x7F800000u && (u & 0x007FFFFFu) != 0u) n_nan++; + } + if (n_nan > 0) { + if (corrupt_f32_count == 0) { + first_corrupt_name = tname; + first_corrupt_nan = n_nan; + } + corrupt_f32_count++; + } + } ggml_backend_tensor_set(t, (const uint8_t *)mm.addr + off, 0, sz); total += sz; } + if (corrupt_f32_count > 0) { + char buf[384]; + std::snprintf(buf, sizeof(buf), + "GGUF has %d F32 weight tensor(s) with NaN values (first: '%s' with %zu NaN). " + "This GGUF is corrupt — every NaN in a norm/SSM weight propagates to all " + "downstream logits. Re-download or use a different quantization.", + corrupt_f32_count, first_corrupt_name.c_str(), first_corrupt_nan); + set_last_error(buf); + gguf_free(gctx); + return false; + } // ── 4b. Read NVFP4 per-tensor weight scales (optional; 1.0 for non-NVFP4). // diff --git a/dflash/src/qwen35/graph_builders.cpp b/dflash/src/qwen35/graph_builders.cpp index c1f51cdb..ddb40271 100644 --- a/dflash/src/qwen35/graph_builders.cpp +++ b/dflash/src/qwen35/graph_builders.cpp @@ -92,7 +92,8 @@ bool build_target_step( bool capture_delta_intermediate, int fa_window, bool last_token_logits_only, - int kq_stride_pad) { + int kq_stride_pad, + bool capture_all_norm_hidden) { step_graph_free(sg); ggml_init_params ip{}; @@ -135,10 +136,15 @@ bool build_target_step( gi.capture_delta_intermediate = capture_delta_intermediate; gi.fa_window = fa_window; gi.last_token_logits_only = last_token_logits_only; + gi.capture_all_norm_hidden = capture_all_norm_hidden; QwenGraphOutputs go = build_qwen35_graph(sg.ctx, sg.gf, w, cache, gi); if (!go.logits) return false; sg.logits = go.logits; + sg.last_norm_hidden = go.last_norm_hidden; + sg.all_norm_hidden = go.all_norm_hidden; // null unless capture_all_norm_hidden + sg.last_h_pre_norm = go.last_h_pre_norm; // null unless capture_all_norm_hidden + sg.all_h_pre_norm = go.all_h_pre_norm; // null unless capture_all_norm_hidden sg.delta_captures = std::move(go.delta_captures); ggml_set_output(sg.logits); diff --git a/dflash/src/qwen35/graph_builders.h b/dflash/src/qwen35/graph_builders.h index 323e8e3d..7ba24149 100644 --- a/dflash/src/qwen35/graph_builders.h +++ b/dflash/src/qwen35/graph_builders.h @@ -53,7 +53,8 @@ bool build_target_step( bool capture_delta_intermediate = false, int fa_window = 0, bool last_token_logits_only = false, - int kq_stride_pad = KQ_MASK_PAD); + int kq_stride_pad = KQ_MASK_PAD, + bool capture_all_norm_hidden = false); // Full target forward: DDTree tree-verify mode. bool build_target_step_tree( diff --git a/dflash/src/qwen35/qwen35_backend.cpp b/dflash/src/qwen35/qwen35_backend.cpp index 8b08d69e..d8bb95fd 100644 --- a/dflash/src/qwen35/qwen35_backend.cpp +++ b/dflash/src/qwen35/qwen35_backend.cpp @@ -9,7 +9,10 @@ #include "common/sampler.h" #include "common/io_utils.h" #include "common/restore_delta.h" +#include "common/mtp_chain_runner.h" +#include "common/mtp_orchestrator.h" #include "qwen3/qwen3_drafter.h" +#include "qwen35/qwen35_mtp.h" #include "ggml-cuda.h" #include "common/snapshot_backend.h" @@ -69,7 +72,7 @@ bool Qwen35Backend::init() { } std::printf("[target] %s\n", dflash27b_last_error()); - // Load draft + // Load draft (skipped in MTP mode — mtp_gguf_path replaces the draft) if (cfg_.draft_path) { std::string dp(cfg_.draft_path); bool draft_ok = (dp.size() >= 5 && dp.substr(dp.size() - 5) == ".gguf") @@ -90,10 +93,23 @@ bool Qwen35Backend::init() { } } - // Create KV cache - const int max_verify_tokens = cfg_.ddtree_mode - ? std::max(dw_.block_size, cfg_.ddtree_budget + 1) - : dw_.block_size; + // Create KV cache. + // MTP mode: size for max(gamma+1, ddtree_budget+1) verify tokens so the + // speculative verify batch fits even without a DFlash draft block size. + // DFlash mode: existing logic uses dw_.block_size. + int max_verify_tokens = 0; + if (cfg_.mtp_gguf_path) { + const int mtp_gamma_eff = std::max(1, cfg_.mtp_gamma); + const int budget_eff = cfg_.ddtree_mode ? cfg_.ddtree_budget : 0; + max_verify_tokens = std::max(mtp_gamma_eff + 1, budget_eff + 1); + // Ensure at least DFLASH27B_DRAFT_BLOCK_SIZE so internal buffers + // allocated by create_target_cache are sized conservatively. + max_verify_tokens = std::max(max_verify_tokens, DFLASH27B_DRAFT_BLOCK_SIZE); + } else { + max_verify_tokens = cfg_.ddtree_mode + ? std::max(dw_.block_size, cfg_.ddtree_budget + 1) + : dw_.block_size; + } if (!create_target_cache(w_, cfg_.device.max_ctx, max_verify_tokens, target_backend_, cache_, /*prefill_only=*/true)) { std::fprintf(stderr, "cache: %s\n", dflash27b_last_error()); @@ -113,6 +129,13 @@ bool Qwen35Backend::init() { } } + // Init MTP speculator when configured. + if (cfg_.mtp_gguf_path) { + if (!init_mtp_()) { + return false; + } + } + return true; } @@ -408,6 +431,18 @@ DFlashTarget * Qwen35Backend::dflash_target() { return dflash_target_.get(); } +// ── Test/bench integration hooks ──────────────────────────────────────── + +bool Qwen35Backend::ensure_decode_cache(int max_verify_tokens) { + return migrate_prefill_cache(w_, cfg_.device.max_ctx, + max_verify_tokens, + target_backend_, cache_); +} + +ggml_context * Qwen35Backend::tensor_context() const { + return w_.ctx; +} + // ── Shutdown ──────────────────────────────────────────────────────────── void Qwen35Backend::shutdown() { @@ -452,17 +487,19 @@ GenerateResult Qwen35Backend::generate(const GenerateRequest & req, // position-addressed and will be overwritten during prefill. reset_recurrent_state(cache_); - // Prefill - auto t_prefill_start = std::chrono::steady_clock::now(); - const int committed = do_prefill(req.prompt, out_io, req.snap_pos, req.snap_slot); - if (committed < 0) { - result.error = "prefill"; - return result; + // MTP path: delegate to common orchestrator. Cache was already sized in + // init_mtp_() — no per-request migrate (idempotent no-op). + if (supports_mtp()) { + return mtp::warm_and_decode(this, req, io); } - auto t_prefill_end = std::chrono::steady_clock::now(); - result.prefill_s = std::chrono::duration(t_prefill_end - t_prefill_start).count(); - // Decode (speculative) + // DFlash / AR path + auto t_prefill_start = std::chrono::steady_clock::now(); + int committed = do_prefill(req.prompt, out_io, req.snap_pos, req.snap_slot); + if (committed < 0) { result.error = "prefill"; return result; } + result.prefill_s = std::chrono::duration( + std::chrono::steady_clock::now() - t_prefill_start).count(); + if (req.n_gen > 0) { auto t_decode_start = std::chrono::steady_clock::now(); if (!do_spec_decode(committed, req.n_gen, result.tokens, out_io)) { @@ -472,7 +509,6 @@ GenerateResult Qwen35Backend::generate(const GenerateRequest & req, result.decode_s = std::chrono::duration( std::chrono::steady_clock::now() - t_decode_start).count(); } - result.ok = true; return result; } @@ -506,13 +542,30 @@ GenerateResult Qwen35Backend::restore_and_generate(int slot, // only the delta at KV positions [snap_pos, snap_pos + delta.size()). int committed = snap_pos; const int prompt_len = (int)req.prompt.size(); + std::vector all_prefill_hidden; + if (prompt_len > snap_pos) { auto t_prefill_start = std::chrono::steady_clock::now(); std::vector delta = restore_prompt_delta(req.prompt, snap_pos); - committed = do_prefill(delta, out_io, req.snap_pos, req.snap_slot, /*kv_offset=*/snap_pos); - if (committed < 0) { - result.error = "prefill"; - return result; + + if (supports_mtp()) { + // MTP path: prefill delta via DFlashTarget::verify_batch so hidden + // states are captured for warm_mtp_for_prompt_. + committed = do_mtp_prefill_(delta, all_prefill_hidden, /*kv_offset=*/snap_pos); + if (committed < 0) { + result.error = "mtp_prefill"; + return result; + } + if (!warm_mtp_for_prompt_(req.prompt, all_prefill_hidden, cache_.last_tok)) { + result.error = "mtp_warm"; + return result; + } + } else { + committed = do_prefill(delta, out_io, req.snap_pos, req.snap_slot, /*kv_offset=*/snap_pos); + if (committed < 0) { + result.error = "prefill"; + return result; + } } result.prefill_s = std::chrono::duration( std::chrono::steady_clock::now() - t_prefill_start).count(); @@ -526,7 +579,10 @@ GenerateResult Qwen35Backend::restore_and_generate(int slot, // Decode if (req.n_gen > 0) { auto t_decode_start = std::chrono::steady_clock::now(); - if (!do_spec_decode(committed, req.n_gen, result.tokens, out_io)) { + bool ok = supports_mtp() + ? do_mtp_decode_(committed, req.n_gen, result.tokens, io) + : do_spec_decode(committed, req.n_gen, result.tokens, out_io); + if (!ok) { result.error = "decode"; return result; } @@ -975,4 +1031,286 @@ int Qwen35Backend::verify_tree(int committed, const DDTree & tree) { return 0; } +// ── MTP init helper ───────────────────────────────────────────────────────── +// +// Mirrors run_qwen35_mtp_harness lines 728-746: construct Qwen35MtpModule, +// load weights from the fused GGUF sharing the backbone tensor context, +// then attach to the DFlashTarget adapter. +// +// Called from init() when cfg_.mtp_gguf_path != nullptr. + +bool Qwen35Backend::init_mtp_() { + // Ensure the DFlashTarget adapter exists (lazy-built); the MTP module + // needs it for attach() and for reading last_hidden() during decode. + DFlashTarget * target = dflash_target(); + if (!target) { + std::fprintf(stderr, "[mtp] dflash_target() unavailable\n"); + return false; + } + + mtp_module_ = std::make_unique(); + std::string err; + // Use the 3-arg init so MTP tensors are loaded from the MTP GGUF's own + // context (blk.64.* live there, not in the backbone ctx). + if (!mtp_module_->init(cfg_.mtp_gguf_path, target, err)) { + std::fprintf(stderr, "[mtp] init failed: %s\n", err.c_str()); + mtp_module_.reset(); + return false; + } + + if (!mtp_module_->attach(target)) { + std::fprintf(stderr, "[mtp] attach(target) failed\n"); + mtp_module_.reset(); + return false; + } + + // Single source of truth for γ: bind to module at attach time. Orchestrator + // + runner read module->effective_gamma() — no parallel storage. This is + // why the regression in PR #214's earlier orchestrator (which silently used + // module->max_gamma()) cannot happen by construction once this is in place. + mtp_module_->set_effective_gamma(cfg_.mtp_gamma); + + // Pre-size the rollback cache once for the MTP gamma chosen at attach. + // Per momus audit of cubic#3257248868: hoisting out of generate() makes + // the OOM-on-first-request → nullptr ssm_intermediate → segfault path + // unreachable (max_ctx + γ are config-time constants; check return here + // where we can fail backend init cleanly). + const int gamma_eff = (cfg_.mtp_gamma > 0) ? cfg_.mtp_gamma : 3; + if (!migrate_prefill_cache(w_, cfg_.device.max_ctx, + std::max(gamma_eff + 1, DFLASH27B_DRAFT_BLOCK_SIZE), + target_backend_, cache_)) { + std::fprintf(stderr, "[mtp] migrate_prefill_cache failed (max_ctx=%d gamma=%d)\n", + cfg_.device.max_ctx, gamma_eff); + mtp_module_.reset(); + return false; + } + + if (cfg_.mtp_use_topk) { + mtp_module_->set_draft_topk(std::max(1, cfg_.mtp_draft_topk)); + } + + std::printf("[mtp] loaded gamma=%d source=%s\n", + cfg_.mtp_gamma, + cfg_.mtp_use_topk ? "mtp_topk" : "chain"); + std::fflush(stdout); + return true; +} + +// ── MTP warm helper ────────────────────────────────────────────────────────── +// +// Mirrors run_qwen35_mtp_harness lines 783-792: seed the MTP module with the +// backbone's final post-norm hidden after prefill, then warm the head KV cache +// over all prefill positions. +// +// prompt : full prompt token sequence (length N) +// all_prefill_hidden : backbone post-norm hiddens [N * hidden_size], F32 +// prefill_next : backbone argmax at end of prefill (t_{N}) +// +// Returns false on failure; the caller (Phase 4's do_prefill hook or generate) +// should treat this as an error. + +bool Qwen35Backend::warm_mtp_for_prompt_(const std::vector & prompt, + const std::vector & all_prefill_hidden, + int32_t prefill_next) { + if (!mtp_module_) return true; // MTP not configured — no-op + + DFlashTarget * target = dflash_target(); + + // Seed with the backbone's final hidden for the last prefill token. + if (target && target->last_hidden()) { + mtp_module_->set_initial_hidden(target->last_hidden(), target->hidden_size()); + } + + // Warm the head KV cache if we have the full-sequence hiddens. + if (!all_prefill_hidden.empty() && prefill_next >= 0) { + if (!mtp_module_->warm_head_kv(prompt.data(), + static_cast(prompt.size()), + prefill_next, + all_prefill_hidden.data())) { + std::fprintf(stderr, "[mtp] warm_head_kv failed\n"); + return false; + } + } + + return true; +} + +// ── MTP prefill helper ─────────────────────────────────────────────────────── +// +// Routes chunked prefill through DFlashTarget::verify_batch instead of the +// raw build_target_step path in do_prefill. This populates last_hidden_seq +// on the Qwen35DFlashTarget so warm_mtp_for_prompt_ has per-position hidden +// states for warm_head_kv. kv_offset > 0 means we are resuming after a +// snapshot restore (same semantics as do_prefill's kv_offset). +// +// Returns the committed KV position (== kv_offset + tokens.size()) or -1 +// on failure. + +int Qwen35Backend::do_mtp_prefill_(const std::vector & tokens, + std::vector & all_prefill_hidden_out, + int kv_offset) { + const int prompt_len = (int)tokens.size(); + if (prompt_len == 0) return kv_offset; + + DFlashTarget * target = dflash_target(); + if (!target) { + std::fprintf(stderr, "[mtp_prefill] dflash_target unavailable\n"); + return -1; + } + + // Cast to the concrete type to enable hidden-sequence capture. Safe + // because Qwen35Backend always constructs a Qwen35DFlashTarget via + // dflash_target(). + auto * q35target = static_cast(target); + + // Ensure the rollback cache exists (needed for verify_batch's chain-capture + // plumbing and for do_mtp_decode_'s snapshot_kv/restore_kv calls). + const int gamma_eff = (cfg_.mtp_gamma > 0) ? cfg_.mtp_gamma : 3; + migrate_prefill_cache(w_, cfg_.device.max_ctx, + std::max(gamma_eff + 1, DFLASH27B_DRAFT_BLOCK_SIZE), + target_backend_, cache_); + + // Enable full-sequence hidden capture so last_hidden_seq is populated + // after each verify_batch chunk. Switch back to LAST_ROW_ONLY after + // prefill so decode-side chain verifies only pay one-row download cost. + q35target->enable_hidden_seq_capture(true); + q35target->set_hidden_capture_mode(Qwen35DFlashTarget::VerifyCaptureMode::FULL_SEQ); + + const int hidden = target->hidden_size(); + all_prefill_hidden_out.resize((size_t)prompt_len * hidden); + + int prefill_ubatch = 512; + if (const char * s = std::getenv("DFLASH27B_PREFILL_UBATCH")) { + prefill_ubatch = std::max(1, std::atoi(s)); + } + + int committed = kv_offset; + int32_t last_tok = -1; + + for (int start = 0; start < prompt_len;) { + const int n = std::min(prefill_ubatch, prompt_len - start); + const int kv_pos = kv_offset + start; + + std::vector chunk(tokens.begin() + start, + tokens.begin() + start + n); + + // verify_batch handles: embed, positions, mask, compute, argmax, + // last_hidden, last_hidden_seq. It also advances cache_.cur_pos. + if (!target->verify_batch(chunk, kv_pos, last_tok, nullptr)) { + std::fprintf(stderr, "[mtp_prefill] verify_batch failed at kv_pos=%d\n", kv_pos); + q35target->enable_hidden_seq_capture(false); + return -1; + } + + // Collect per-chunk hidden states into all_prefill_hidden_out. + int n_chunk = 0; + const float * h_seq = target->last_hidden_seq(&n_chunk); + if (h_seq && n_chunk == n) { + std::memcpy(all_prefill_hidden_out.data() + (size_t)start * hidden, + h_seq, + sizeof(float) * (size_t)n * hidden); + } else { + std::fprintf(stderr, + "[mtp_prefill] hidden seq missing: expected %d tokens, got %d\n", + n, n_chunk); + // Non-fatal — warm_head_kv will be skipped with an empty hidden buffer. + all_prefill_hidden_out.clear(); + } + + committed = kv_pos + n; + start += n; + } + + // Record last token for do_mtp_decode_ (mirrors do_prefill's cache_.last_tok). + cache_.last_tok = last_tok; + + // Switch to LAST_ROW_ONLY for decode phase — subsequent chain-runner + // verify_batch calls only need hidden_at_pos(base_pos - 1). + q35target->set_hidden_capture_mode(Qwen35DFlashTarget::VerifyCaptureMode::LAST_ROW_ONLY); + + return committed; +} + +// ── MTP decode helper ──────────────────────────────────────────────────────── +// +// Drives MtpChainRunner after prefill. Mirrors harness lines 826-841 (chain +// path only — mtp_topk is Phase 4 out-of-scope). +// +// Emits tokens via io.emit per token (req.stream=true inside the runner), and +// a terminal io.emit(-1) is issued by the runner. The caller must NOT emit +// additional -1 tokens. +// +// Returns false on failure; caller sets result.error = "decode". + +bool Qwen35Backend::do_mtp_decode_(int committed, int n_gen, + std::vector & out_tokens, + const DaemonIO & io) { + if (!mtp_module_) return false; + if (n_gen <= 0) { + io.emit(-1); + return true; + } + + // The argmax produced at the end of prefill — this is the first generated + // token (matches harness line 806: generated.push_back(prefill_next)). + const int32_t prefill_next = cache_.last_tok; + if (prefill_next < 0) { + std::fprintf(stderr, "[mtp_decode] prefill_next invalid (%d)\n", prefill_next); + return false; + } + + // Emit the prefill token immediately. + out_tokens.push_back(prefill_next); + io.emit(prefill_next); + if (IS_EOS_TOK(prefill_next, w_)) { + io.emit(-1); + return true; + } + if (n_gen == 1) { + io.emit(-1); + return true; + } + + DFlashTarget * target = dflash_target(); + if (!target) { + std::fprintf(stderr, "[mtp_decode] dflash_target unavailable\n"); + return false; + } + + const int gamma = (cfg_.mtp_gamma > 0) ? cfg_.mtp_gamma : 3; + + // MtpChainRunner with stream=true: the runner calls io.emit() per accepted + // token and issues io.emit(-1) on completion. No double-emit needed here. + GenerateRequest req_inner; + req_inner.n_gen = n_gen - 1; // prefill_next already consumed + req_inner.stream = true; + req_inner.do_sample = false; + req_inner.sampler = sampler_; + + mtp::MtpChainRunner runner(*mtp_module_, *target, sampler_); + GenerateResult res = runner.run(req_inner, io, prefill_next, committed, gamma); + + if (!res.ok) { + std::fprintf(stderr, "[mtp_decode] chain runner failed: %s\n", res.error.c_str()); + return false; + } + + // Append tokens the runner emitted (runner already streamed them via io.emit). + for (int32_t t : res.tokens) { + out_tokens.push_back(t); + } + + // Log acceptance stats (mirrors harness stderr line 1109). + const auto & st = runner.stats(); + if (st.total_iters > 0) { + std::fprintf(stderr, + "[mtp_decode] iters=%d proposed=%d accepted=%d emitted=%d accept_rate=%.2f\n", + st.total_iters, st.total_proposed, st.total_accepted, st.total_emitted, + st.total_proposed > 0 + ? (double)st.total_accepted / (double)st.total_proposed : 0.0); + } + + return true; +} + } // namespace dflash::common diff --git a/dflash/src/qwen35/qwen35_backend.h b/dflash/src/qwen35/qwen35_backend.h index d87f5f0b..230f1c37 100644 --- a/dflash/src/qwen35/qwen35_backend.h +++ b/dflash/src/qwen35/qwen35_backend.h @@ -12,6 +12,7 @@ #pragma once #include "common/model_backend.h" +#include "common/backend_factory.h" #include "common/dflash_target.h" #include "common/device_placement.h" #include "step_graph.h" @@ -19,6 +20,7 @@ #include "dflash_feature_ring.h" #include "internal.h" // TargetWeights, TargetCache, DraftWeights, PrefixSnapshot #include "qwen3/qwen3_drafter.h" // DrafterContext, load_drafter, free_drafter, drafter_score_and_compress +#include "qwen35/qwen35_mtp.h" // Qwen35MtpModule #include "ggml.h" #include "ggml-backend.h" @@ -27,6 +29,7 @@ #include #include #include +#include namespace dflash::common { @@ -55,6 +58,15 @@ struct Qwen35Config { float ddtree_temp = 1.0f; bool ddtree_chain_seed = true; bool use_feature_mirror = false; + + // MTP (Multi-Token Prediction) speculator — mutually exclusive with draft. + // mtp_gguf_path != nullptr is the MTP-active sentinel (set by backend_factory + // based on MtpSource). For MtpSource::Native it is set to target_path; + // for MtpSource::ExternalDrafter it is the external GGUF path. + const char * mtp_gguf_path = nullptr; // path to GGUF containing MTP tensors + int mtp_gamma = 0; // max speculation depth + bool mtp_use_topk = false; // false = chain (default), true = mtp_topk + int mtp_draft_topk = 1; // top-k for mtp_topk mode }; // ── Backend class ─────────────────────────────────────────────────────── @@ -107,6 +119,15 @@ class Qwen35Backend : public ModelBackend { bool supports_dflash_spec_decode() const override { return true; } DFlashTarget * dflash_target() override; + // Test-only: not part of the ModelBackend interface. Used by test_dflash + // harness for direct decode-cache + tensor-context introspection. + bool ensure_decode_cache(int max_verify_tokens); + ggml_context * tensor_context() const; + + // MTP module accessors (ModelBackend interface). + bool supports_mtp() const override { return mtp_module_ != nullptr; } + mtp::IMtpModule * mtp() override { return mtp_module_.get(); } + void shutdown() override; private: @@ -156,6 +177,9 @@ class Qwen35Backend : public ModelBackend { // ── DFlashTarget adapter (lazy-built) ──────────────────────────── std::unique_ptr dflash_target_; + // ── MTP speculator (optional, set when cfg_.mtp_gguf_path != nullptr) ── + std::unique_ptr mtp_module_; + // ── Internal helpers ───────────────────────────────────────────── // Prefill a prompt and return the number of tokens committed to KV. // kv_offset > 0 resumes from a restored snapshot: tokens are placed at @@ -180,6 +204,31 @@ class Qwen35Backend : public ModelBackend { // DDTree tree-mode verify. int verify_tree(int committed, const DDTree & tree); + + // MTP init: load and attach the Qwen35MtpModule. Called from init() when + // cfg_.mtp_gguf_path is set. Returns false on failure. + bool init_mtp_(); + + // MTP warm: seed the head KV cache after prefill. Mirrors harness lines + // 783-792: set_initial_hidden + warm_head_kv. prefill_next is the argmax + // token produced by the last prefill chunk. + bool warm_mtp_for_prompt_(const std::vector & prompt, + const std::vector & all_prefill_hidden, + int32_t prefill_next); + + // MTP prefill: chunked prefill via DFlashTarget::verify_batch, collecting + // all_prefill_hidden for each chunk so warm_mtp_for_prompt_ can seed the + // head KV cache. Returns committed KV position (>= 0) or -1 on error. + // kv_offset > 0 resumes from a snapshot (same semantics as do_prefill). + int do_mtp_prefill_(const std::vector & tokens, + std::vector & all_prefill_hidden_out, + int kv_offset = 0); + + // MTP decode: drive MtpChainRunner after prefill. Emits tokens via io.emit + // (including the terminal -1). Returns false on error. + bool do_mtp_decode_(int committed, int n_gen, + std::vector & out_tokens, + const DaemonIO & io); }; } // namespace dflash::common diff --git a/dflash/src/qwen35/qwen35_daemon.cpp b/dflash/src/qwen35/qwen35_daemon.cpp index d1a14a91..ce18bd4a 100644 --- a/dflash/src/qwen35/qwen35_daemon.cpp +++ b/dflash/src/qwen35/qwen35_daemon.cpp @@ -30,6 +30,28 @@ int run_qwen35_daemon(const Qwen35DaemonArgs & args) { cfg.ddtree_temp = args.ddtree_temp; cfg.ddtree_chain_seed = args.ddtree_chain_seed; cfg.use_feature_mirror = args.use_feature_mirror; + cfg.mtp_gamma = args.mtp_gamma; + cfg.mtp_use_topk = args.mtp_use_topk; + cfg.mtp_draft_topk = args.mtp_draft_topk; + + // Resolve MtpSource to the mtp_gguf_path sentinel that Qwen35Backend expects. + switch (args.mtp_source) { + case MtpSource::Native: + cfg.mtp_gguf_path = args.target_path; + break; + case MtpSource::ExternalDrafter: + cfg.mtp_gguf_path = args.mtp_gguf_path; + break; + case MtpSource::Auto: + cfg.mtp_gguf_path = dflash::common::gguf_contains_mtp_tensors(args.target_path) + ? args.target_path + : nullptr; + break; + case MtpSource::None: + default: + cfg.mtp_gguf_path = nullptr; + break; + } Qwen35Backend backend(cfg); if (!backend.init()) return 1; diff --git a/dflash/src/qwen35/qwen35_daemon.h b/dflash/src/qwen35/qwen35_daemon.h index 2135eca3..8e8f6be8 100644 --- a/dflash/src/qwen35/qwen35_daemon.h +++ b/dflash/src/qwen35/qwen35_daemon.h @@ -5,6 +5,7 @@ #pragma once +#include "common/backend_factory.h" #include "device_placement.h" #include @@ -34,6 +35,14 @@ struct Qwen35DaemonArgs { float ddtree_temp = 1.0f; bool ddtree_chain_seed = true; bool use_feature_mirror = false; + + // MTP (Multi-Token Prediction) speculator — mutually exclusive with draft. + // The daemon uses BackendArgs directly; these fields mirror BackendArgs. + MtpSource mtp_source = MtpSource::None; + const char * mtp_gguf_path = nullptr; // required only for ExternalDrafter + int mtp_gamma = 0; // max speculation depth + bool mtp_use_topk = false; // false = chain, true = mtp_topk + int mtp_draft_topk = 1; // top-k for mtp_topk mode }; // Run the qwen35 daemon loop. Returns 0 on clean exit, 1 on init failure. diff --git a/dflash/src/qwen35/qwen35_dflash_target.cpp b/dflash/src/qwen35/qwen35_dflash_target.cpp index 65713d1b..51f2435a 100644 --- a/dflash/src/qwen35/qwen35_dflash_target.cpp +++ b/dflash/src/qwen35/qwen35_dflash_target.cpp @@ -1,14 +1,66 @@ // Qwen35DFlashTarget — DFlashTarget adapter for qwen35 hybrid models. #include "qwen35_dflash_target.h" +#include "common/ddtree.h" #include "graph_builders.h" #include "step_graph.h" #include "attn_masks.h" +#include "device_runtime.h" // cudaStream_t, cudaMemcpyAsync, cudaMemcpy2DAsync + +#include +#include +#include + +// ggml-cuda dequantize helper (used to widen F16/Q8_0 ssm_intermediate slots +// back to F32 for cache_.ssm_state). Same trick as test_dflash.cpp and +// dflash_feature_ring.cpp — the symbol lives in ggml-cuda/convert.cuh and has +// no public header, so forward-declare the typedef here. +using to_fp32_cuda_t = void (*)(const void *, float *, int64_t, cudaStream_t); +extern "C++" to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type); namespace dflash::common { +#ifdef DFLASH_VERIFY_PROFILE +// Per-call profiler for verify_batch. Enabled by -DDFLASH_VERIFY_PROFILE=1. +// Host-side wall-clock around ggml_backend_* calls IS the GPU+sync latency +// because every set/get/compute internally calls cudaStreamSynchronize. +namespace { +inline bool verify_profile_enabled() { + static const bool on = (std::getenv("DFLASH_VERIFY_PROFILE") != nullptr); + return on; +} + +using vprof_clock = std::chrono::steady_clock; +inline double vprof_ms_since(vprof_clock::time_point t0) { + auto t1 = vprof_clock::now(); + return std::chrono::duration(t1 - t0).count(); +} +} // namespace +#endif // DFLASH_VERIFY_PROFILE + Qwen35DFlashTarget::~Qwen35DFlashTarget() { +#ifdef DFLASH_VERIFY_PROFILE + if (verify_profile_enabled() && vprof_n_calls_ > 0) { + const double inv = 1.0 / (double)vprof_n_calls_; + std::fprintf(stderr, + "[verify_prof summary calls=%lld " + "avg_set=%.3f avg_compute=%.3f avg_get_hidden=%.3f " + "avg_get_hpre=%.3f avg_get_argmax=%.3f avg_total=%.3f " + "(sum_set=%.1f sum_compute=%.1f sum_get_hidden=%.1f " + "sum_get_hpre=%.1f sum_get_argmax=%.1f sum_total=%.1f ms)]\n", + (long long)vprof_n_calls_, + vprof_sum_set_ * inv, + vprof_sum_compute_ * inv, + vprof_sum_get_hidden_ * inv, + vprof_sum_get_hpre_ * inv, + vprof_sum_get_argmax_ * inv, + vprof_sum_total_ * inv, + vprof_sum_set_, vprof_sum_compute_, vprof_sum_get_hidden_, + vprof_sum_get_hpre_, vprof_sum_get_argmax_, vprof_sum_total_); + } +#endif // DFLASH_VERIFY_PROFILE step_graph_destroy(proj_sg_); + if (rollback_stream_) { cudaStreamDestroy(rollback_stream_); rollback_stream_ = nullptr; } } Qwen35DFlashTarget::Qwen35DFlashTarget( @@ -35,25 +87,61 @@ bool Qwen35DFlashTarget::verify_batch( const int hidden = w_.n_embd; const bool need_mask = (kq_stride_pad_ > KQ_MASK_PAD) || (n_tokens > 1); + // Per-position DeltaNet intermediate capture is only safe when: + // 1. The caller (MTP chain runner) opted in via enable_chain_capture(true). + // 2. The cache buffers actually exist (migrate_prefill_cache ran). + // 3. n_tokens fits in the pre-allocated cache. The conv_input cache is + // [(d_conv-1) + max_verify_tokens, conv_ch, 1] and the in-graph + // ggml_view_3d into it asserts when n_tokens > max_verify_tokens + // (e.g. 512-token prefill chunks vs 16-slot cache). The + // ssm_intermediate ggml_cpy also requires n_tokens == max_verify_tokens + // after the matching dst-side view (see qwen35_target_graph.cpp). + int max_verify_tokens = 0; + if (!cache_.ssm_intermediate.empty() && cache_.ssm_intermediate[0] != nullptr) { + max_verify_tokens = (int)cache_.ssm_intermediate[0]->ne[3]; + } + const bool capture_intermediate = + chain_capture_enabled_ && + max_verify_tokens > 0 && + n_tokens <= max_verify_tokens; + if (!build_target_step(sg_, w_, cache_, backend_, /*kv_start=*/base_pos, n_tokens, need_mask, /*capture=*/true, - /*capture_delta_intermediate=*/false, + /*capture_delta_intermediate=*/capture_intermediate, fa_window_, /*last_token_logits_only=*/false, - kq_stride_pad_)) { - std::fprintf(stderr, "verify_batch: build_target_step failed (base=%d n=%d)\n", base_pos, n_tokens); + kq_stride_pad_, + /*capture_all_norm_hidden=*/capture_hidden_seq_)) { + std::fprintf(stderr, "verify_batch: build_target_step failed (base=%d n=%d)\n", + base_pos, n_tokens); return false; } +#ifdef DFLASH_VERIFY_PROFILE + // Per-call profiling state (DFLASH_VERIFY_PROFILE=1). + const bool vprof_on = verify_profile_enabled(); + double vp_set = 0.0, vp_compute = 0.0; + double vp_get_hidden = 0.0, vp_get_hpre = 0.0, vp_get_argmax = 0.0; + vprof_clock::time_point vp_t_total = + vprof_on ? vprof_clock::now() : vprof_clock::time_point{}; + vprof_clock::time_point vp_t0; +#endif // DFLASH_VERIFY_PROFILE + // Embed input tokens and fill positions. std::vector embed((size_t)n_tokens * hidden); if (!w_.embedder.embed(tokens.data(), n_tokens, embed.data())) { std::fprintf(stderr, "verify_batch: embed failed (n=%d)\n", n_tokens); return false; } +#ifdef DFLASH_VERIFY_PROFILE + vp_t0 = vprof_on ? vprof_clock::now() : vprof_clock::time_point{}; +#endif ggml_backend_tensor_set(sg_.inp_embed, embed.data(), 0, sizeof(float) * embed.size()); +#ifdef DFLASH_VERIFY_PROFILE + if (vprof_on) vp_set += vprof_ms_since(vp_t0); +#endif // Qwen35 uses interleaved positions: 4 ints per token. std::vector pos(4 * n_tokens); @@ -63,10 +151,19 @@ bool Qwen35DFlashTarget::verify_batch( pos[4 * i + 2] = base_pos + i; pos[4 * i + 3] = 0; } +#ifdef DFLASH_VERIFY_PROFILE + vp_t0 = vprof_on ? vprof_clock::now() : vprof_clock::time_point{}; +#endif ggml_backend_tensor_set(sg_.positions, pos.data(), 0, sizeof(int32_t) * pos.size()); +#ifdef DFLASH_VERIFY_PROFILE + if (vprof_on) vp_set += vprof_ms_since(vp_t0); +#endif - // Fill causal attention mask when present. + // Populate the causal attention mask. The mask buffer is freshly allocated + // by build_target_step (uninitialized memory); without this set, attention + // reads garbage and ggml_argmax over the resulting logits returns -1 for + // non-last positions, breaking the chain runner's recommit path. if (sg_.attn_mask) { const int win_start = (fa_window_ > 0 && base_pos > fa_window_) ? (base_pos - fa_window_) : 0; @@ -75,17 +172,37 @@ bool Qwen35DFlashTarget::verify_batch( const int kv_pad_override = (int)sg_.attn_mask->ne[0]; build_causal_mask(mask_buf, kv_len, n_tokens, base_pos, kq_stride_pad_, win_start, kv_pad_override); +#ifdef DFLASH_VERIFY_PROFILE + vp_t0 = vprof_on ? vprof_clock::now() : vprof_clock::time_point{}; +#endif ggml_backend_tensor_set(sg_.attn_mask, mask_buf.data(), 0, sizeof(uint16_t) * mask_buf.size()); +#ifdef DFLASH_VERIFY_PROFILE + if (vprof_on) vp_set += vprof_ms_since(vp_t0); +#endif } + // Mask was already filled earlier (see the mask-fill block above); no + // duplicate fill here. Just compute and profile. +#ifdef DFLASH_VERIFY_PROFILE + vp_t0 = vprof_on ? vprof_clock::now() : vprof_clock::time_point{}; +#endif auto st = ggml_backend_graph_compute(backend_, sg_.gf); +#ifdef DFLASH_VERIFY_PROFILE + if (vprof_on) vp_compute = vprof_ms_since(vp_t0); +#endif if (st != GGML_STATUS_SUCCESS) { - std::fprintf(stderr, "verify_batch: compute failed (status=%d)\n", (int)st); + std::fprintf(stderr, "verify_batch: compute failed (status=%d base_pos=%d n_tokens=%d)\n", + (int)st, base_pos, n_tokens); return false; } - // Read argmax results from GPU. + // Read argmax for every position. The chain runner needs all_argmax to + // accept/reject per-position drafts; for non-chain callers (all_argmax + // null) we still need last_tok. Single read covers both. +#ifdef DFLASH_VERIFY_PROFILE + vp_t0 = vprof_on ? vprof_clock::now() : vprof_clock::time_point{}; +#endif std::vector argmax_buf(n_tokens); ggml_backend_tensor_get(sg_.argmax_tokens, argmax_buf.data(), 0, sizeof(int32_t) * n_tokens); @@ -94,11 +211,425 @@ bool Qwen35DFlashTarget::verify_batch( if (all_argmax) { *all_argmax = std::move(argmax_buf); } +#ifdef DFLASH_VERIFY_PROFILE + if (vprof_on) vp_get_argmax = vprof_ms_since(vp_t0); +#endif + + // Copy the last token's post-norm hidden to a CPU buffer so the MTP module + // can call last_hidden() before the next graph_compute overwrites it. + if (sg_.last_norm_hidden) { + last_hidden_cpu_.resize(hidden); +#ifdef DFLASH_VERIFY_PROFILE + vp_t0 = vprof_on ? vprof_clock::now() : vprof_clock::time_point{}; +#endif + ggml_backend_tensor_get(sg_.last_norm_hidden, last_hidden_cpu_.data(), + 0, sizeof(float) * hidden); +#ifdef DFLASH_VERIFY_PROFILE + if (vprof_on) vp_get_hidden += vprof_ms_since(vp_t0); +#endif + } + + // Copy the full [n_tokens, n_embd] post-norm hidden sequence so the + // Qwen3.6 MTP warm_head_kv() can read per-position hiddens during prefill. + // + // R8 audit (Phase A): this branch overwrites last_hidden_seq_cpu_ via + // .resize() + tensor_get on EVERY verify_batch call — including the + // recommit path at mtp_chain_runner.cpp:206 (recommit calls verify_batch + // which lands here). So `hidden_at_pos(base_pos-1)` on the chain's next + // iteration reads the fresh hiddens from the recommit, not a stale slice. + // Hypothesis from the Phase A brief (stale recommit hiddens as cause of + // the 71.6% per-step accept ceiling) is NOT supported by this code path. + if (sg_.all_norm_hidden) { + // LAST_ROW_ONLY (decode mode): download only row n_tokens-1 of the + // [n_tokens, n_embd] hidden tensors. The chain's only consumer is + // hidden_at_pos(base_pos - 1) on the NEXT verify, where the chain's + // base_pos = (this verify's base_pos + n_tokens), so it asks for + // abs_pos = (base_pos + n_tokens - 1) — exactly the row we keep. + // We stash that single row at offset 0 of last_hidden_seq_cpu_ and + // set last_verify_chunk_start_ = base_pos + n_tokens - 1, so the + // hidden_at_pos() accessor's `rel = abs_pos - chunk_start` formula + // resolves to 0 and returns the right pointer. + // + // FULL_SEQ (prefill mode): unchanged — download the whole sequence + // because warm_head_kv() / last_hidden_seq() walks every position. + const bool last_row_only = + (capture_mode_ == VerifyCaptureMode::LAST_ROW_ONLY) && n_tokens > 0; + const int rows_to_copy = last_row_only ? 1 : n_tokens; + const size_t src_row_off = last_row_only ? (size_t)(n_tokens - 1) : 0; + + last_hidden_seq_cpu_.resize((size_t)rows_to_copy * hidden); +#ifdef DFLASH_VERIFY_PROFILE + vp_t0 = vprof_on ? vprof_clock::now() : vprof_clock::time_point{}; +#endif + ggml_backend_tensor_get(sg_.all_norm_hidden, last_hidden_seq_cpu_.data(), + src_row_off * hidden * sizeof(float), + sizeof(float) * (size_t)rows_to_copy * hidden); +#ifdef DFLASH_VERIFY_PROFILE + if (vprof_on) vp_get_hidden += vprof_ms_since(vp_t0); +#endif + last_hidden_seq_n_ = rows_to_copy; + last_verify_chunk_start_ = last_row_only + ? (base_pos + n_tokens - 1) + : base_pos; + // Mirror the PRE-final-output-norm sequence (set by the graph when + // capture_all_norm_hidden is on). hidden_at_pos_pre_norm() reads + // this; the Qwen3.6 MTP chain seeds h_prev_0 with it (PR #22673 + // `t_h_pre_norm`). If the graph did not expose the pre-norm + // tensor (e.g. older graph builds, defensive fallback), clear the + // buffer so the accessor returns nullptr and the caller falls + // back to the post-norm tensor. + if (sg_.all_h_pre_norm) { + last_hidden_seq_pre_norm_cpu_.resize((size_t)rows_to_copy * hidden); +#ifdef DFLASH_VERIFY_PROFILE + vp_t0 = vprof_on ? vprof_clock::now() : vprof_clock::time_point{}; +#endif + ggml_backend_tensor_get(sg_.all_h_pre_norm, + last_hidden_seq_pre_norm_cpu_.data(), + src_row_off * hidden * sizeof(float), + sizeof(float) * (size_t)rows_to_copy * hidden); +#ifdef DFLASH_VERIFY_PROFILE + if (vprof_on) vp_get_hpre = vprof_ms_since(vp_t0); +#endif + } else { + last_hidden_seq_pre_norm_cpu_.clear(); + } + } else { + last_hidden_seq_n_ = 0; + last_hidden_seq_pre_norm_cpu_.clear(); + } cache_.cur_pos = base_pos + n_tokens; + + // Topology is owned by the caller via capture_topology_for_chain(); if + // capture fired without prior topology, invalidate to be defensive. + if (!capture_intermediate) { + last_tree_base_pos_ = -1; + } + +#ifdef DFLASH_VERIFY_PROFILE + if (vprof_on) { + const double vp_total = vprof_ms_since(vp_t_total); + std::fprintf(stderr, + "[verify_prof n_tokens=%d base_pos=%d set=%.3f compute=%.3f " + "get_hidden=%.3f get_hpre=%.3f get_argmax=%.3f total=%.3f (ms)]\n", + n_tokens, base_pos, vp_set, vp_compute, + vp_get_hidden, vp_get_hpre, vp_get_argmax, vp_total); + vprof_sum_set_ += vp_set; + vprof_sum_compute_ += vp_compute; + vprof_sum_get_hidden_ += vp_get_hidden; + vprof_sum_get_hpre_ += vp_get_hpre; + vprof_sum_get_argmax_ += vp_get_argmax; + vprof_sum_total_ += vp_total; + vprof_n_calls_++; + } +#endif // DFLASH_VERIFY_PROFILE return true; } +bool Qwen35DFlashTarget::verify_tree( + const std::vector & flat_tokens, + const DDTree & tree, + int base_pos, + std::vector & out_argmax, + std::vector * out_logits) { + const int N = (int)flat_tokens.size(); + if (N <= 0) return false; + if (N != 1 + tree.n_nodes) return false; + + // Degenerate single-token tree: cheap fast path through verify_batch. + if (tree.n_nodes == 0) { + int32_t last_tok = -1; + std::vector all_argmax; + if (!verify_batch(flat_tokens, base_pos, last_tok, &all_argmax)) { + return false; + } + out_argmax = std::move(all_argmax); + if (out_logits) out_logits->clear(); + return true; + } + + // Real tree verify — body lifted from test_dflash.cpp:3140-3231 (the + // ddtree-verify branch of run_qwen35_mtp_harness's spec-decode loop) + // minus the walk/commit policy (the harness keeps that). + // + // Stage 3: build_target_step_tree below uses capture_delta_intermediate + // = true (see graph_builders.cpp), so cache_.ssm_intermediate[il] holds + // the per-DFS-slot SSM state and cache_.conv_input_cache[il] holds the + // full conv window. restore_kv_at_dfs() consumes these on partial + // accept to undo rejected siblings before the next iteration. + const int hidden = w_.n_embd; + + if (!build_target_step_tree(sg_, w_, cache_, backend_, + /*kv_start=*/base_pos, /*n_tokens=*/N, + fa_window_, kq_stride_pad_)) { + std::fprintf(stderr, + "[Qwen35DFlashTarget] verify_tree: build_target_step_tree failed: base_pos=%d N=%d\n", + base_pos, N); + return false; + } + + // Embed all N flat tokens (root at slot 0, DFS-ordered tree nodes 1..N-1). + std::vector tree_embed((size_t)hidden * N, 0.0f); + if (!w_.embedder.embed(flat_tokens.data(), N, tree_embed.data())) { + return false; + } + ggml_backend_tensor_set(sg_.inp_embed, tree_embed.data(), 0, + sizeof(float) * (size_t)hidden * N); + + // M-RoPE axis-major positions: root at base_pos, node i at base_pos + depths[i-1]. + // Mirrors the harness pos4 layout (4 ints per axis, axis-major). + std::vector pos4(4 * N, 0); + for (int i = 0; i < N; i++) { + const int p = base_pos + (i == 0 ? 0 : tree.depths[i - 1]); + pos4[0 * N + i] = p; + pos4[1 * N + i] = p; + pos4[2 * N + i] = p; + pos4[3 * N + i] = 0; + } + ggml_backend_tensor_set(sg_.positions, pos4.data(), 0, + sizeof(int32_t) * 4 * N); + + // Ancestor-only attention mask. build_target_step_tree allocated + // sg.attn_mask with kv_pad = align_up(cache.max_ctx + N, kq_stride_pad); + // pin build_tree_mask's kv stride to the same value via kv_pad_override. + const int win_start = (fa_window_ > 0 && base_pos > fa_window_) + ? (base_pos - fa_window_) : 0; + std::vector mask_buf; + build_tree_mask(tree, base_pos, mask_buf, kq_stride_pad_, + win_start, /*kv_pad_override=*/cache_.max_ctx + N); + ggml_backend_tensor_set(sg_.attn_mask, mask_buf.data(), 0, + sizeof(uint16_t) * mask_buf.size()); + + // parent_ids: root is -1; node i (1..n_nodes) points to its tree parent. + std::vector parent_ids(N, 0); + parent_ids[0] = -1; + for (int i = 1; i < N; i++) parent_ids[i] = (int32_t)tree.parents[i]; + ggml_backend_tensor_set(sg_.parent_ids, parent_ids.data(), 0, + sizeof(int32_t) * N); + + auto st = ggml_backend_graph_compute(backend_, sg_.gf); + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, + "[Qwen35DFlashTarget] verify_tree: graph_compute failed: base_pos=%d N=%d status=%d\n", + base_pos, N, (int)st); + return false; + } + + // Read argmax for all N tree positions. + out_argmax.resize(N); + ggml_backend_tensor_get(sg_.argmax_tokens, out_argmax.data(), 0, + sizeof(int32_t) * N); + + // Optional: pull the full [N × vocab] logits. + if (out_logits) { + const int vocab = sg_.logits ? (int)sg_.logits->ne[0] : 0; + if (vocab <= 0) { + out_logits->clear(); + } else { + out_logits->resize((size_t)N * vocab); + ggml_backend_tensor_get(sg_.logits, out_logits->data(), 0, + sizeof(float) * (size_t)N * vocab); + } + } + + cache_.cur_pos = base_pos + N; + + // Stage 3: cache the tree topology so restore_kv_at_dfs() can locate the + // accepted slots, compute the post-rollback cur_pos, and walk ancestry + // for the conv-window K-1 rollback on sibling-walk accept paths. + last_tree_base_pos_ = base_pos; + last_tree_n_nodes_ = tree.n_nodes; + last_tree_parents_.assign(tree.parents.begin(), tree.parents.end()); + last_tree_depths_.assign(tree.depths.begin(), tree.depths.end()); + + return true; +} + +bool Qwen35DFlashTarget::restore_kv_at_dfs(const std::vector & accepted_dfs) { + if (last_tree_base_pos_ < 0) { + std::fprintf(stderr, + "[Qwen35DFlashTarget] restore_kv_at_dfs called before any verify_tree\n"); + return false; + } + if (accepted_dfs.empty()) { + std::fprintf(stderr, + "[Qwen35DFlashTarget] restore_kv_at_dfs: empty accepted_dfs\n"); + return false; + } + if (accepted_dfs[0] != 0) { + std::fprintf(stderr, + "[Qwen35DFlashTarget] restore_kv_at_dfs: accepted_dfs[0]=%d, expected 0 (root)\n", + accepted_dfs[0]); + return false; + } + const int commit_n = (int)accepted_dfs.size(); // includes root + const int rollback_dfs = accepted_dfs[commit_n - 1]; // deepest accepted DFS slot + const int N = 1 + last_tree_n_nodes_; + if (rollback_dfs < 0 || rollback_dfs >= N) { + std::fprintf(stderr, + "[Qwen35DFlashTarget] restore_kv_at_dfs: rollback_dfs=%d out of range [0,%d)\n", + rollback_dfs, N); + return false; + } + // Detect pure-chain walk (accepted[i] == i for every i in the prefix). + // Hot path; lets us short-circuit the per-conv-column gather + KV compaction. + bool walked_sibling = false; + for (int i = 0; i < commit_n; i++) { + if (accepted_dfs[i] != i) { walked_sibling = true; break; } + } + + const int n_delta = (int)cache_.ssm_intermediate.size(); + // Bug #3: dedicated stream so the rollback copies don't serialize with + // the default stream (e.g. ggml backend compute, host syncs). + if (!rollback_stream_) { + if (cudaStreamCreate(&rollback_stream_) != cudaSuccess) { + rollback_stream_ = nullptr; // fall back to default + } + } + cudaStream_t stream = rollback_stream_; + for (int il = 0; il < n_delta; il++) { + ggml_tensor * ssm_inter = cache_.ssm_intermediate[il]; + ggml_tensor * conv_in = cache_.conv_input_cache[il]; + if (!ssm_inter || !conv_in) { + std::fprintf(stderr, + "[Qwen35DFlashTarget] restore_kv_at_dfs: missing capture layer %d " + "(ssm_inter=%p conv_in=%p) — was verify_tree run with " + "capture_delta_intermediate=true?\n", + il, (void*)ssm_inter, (void*)conv_in); + return false; + } + // SSM state rollback (dequant ssm_intermediate[rollback_dfs] → ssm_state[il]). + const size_t ssm_elems = + (size_t)cache_.ssm_state[il]->ne[0] * + (size_t)cache_.ssm_state[il]->ne[1] * + (size_t)cache_.ssm_state[il]->ne[2]; + const size_t ssm_src_off = + (size_t)rollback_dfs * ssm_inter->nb[3]; + const void * ssm_src = (const char *)ssm_inter->data + ssm_src_off; + ggml_get_to_fp32_cuda(ssm_inter->type)( + ssm_src, (float *)cache_.ssm_state[il]->data, + (int64_t)ssm_elems, stream); + + // Conv rollback: copy the K-1 most recent inputs along the accepted + // path's ANCESTRY (not DFS order). Mirror test_dflash.cpp:3395-3437. + const int K_conv = 4; + const int row_cnt = (int)conv_in->ne[1]; + const size_t elt = ggml_element_size(conv_in); + const size_t dpitch = (K_conv - 1) * elt; + const size_t spitch = conv_in->nb[1]; + if (!walked_sibling) { + // Fast path: K_conv-1 = 3 contiguous slots ending at rollback_dfs. + const int conv_off = rollback_dfs + 1; + const void * conv_src = (const char *)conv_in->data + (size_t)conv_off * elt; + cudaError_t ce = cudaMemcpy2DAsync( + cache_.conv_state[il]->data, dpitch, + conv_src, spitch, + (K_conv - 1) * elt, row_cnt, + cudaMemcpyDeviceToDevice, stream); + if (ce != cudaSuccess) { + std::fprintf(stderr, + "[Qwen35DFlashTarget] restore_kv_at_dfs conv fast il=%d: %s\n", + il, cudaGetErrorString(ce)); + return false; + } + } else { + int virt[K_conv - 1]; + virt[K_conv - 2] = rollback_dfs; + for (int k = K_conv - 3; k >= 0; k--) { + const int prev = virt[k + 1]; + virt[k] = (prev >= 0) ? (int)last_tree_parents_[prev] : (prev - 1); + } + for (int k = 0; k < K_conv - 1; k++) { + const int sx_slot = (K_conv - 1) + virt[k]; + const void * src_col = + (const char *)conv_in->data + (size_t)sx_slot * elt; + char * dst_col = + (char *)cache_.conv_state[il]->data + (size_t)k * elt; + cudaError_t ce = cudaMemcpy2DAsync(dst_col, dpitch, + src_col, spitch, + elt, row_cnt, + cudaMemcpyDeviceToDevice, stream); + if (ce != cudaSuccess) { + std::fprintf(stderr, + "[Qwen35DFlashTarget] restore_kv_at_dfs conv col il=%d k=%d: %s\n", + il, k, cudaGetErrorString(ce)); + return false; + } + } + } + } + + // Full-attention KV compaction: verify_tree wrote K/V at slots + // [base..base+N-1] in DFS order. Bug #3: collapse the per-head inner + // loop into one cudaMemcpy2DAsync (pitch=nb[2], height=n_kv) — saves + // 2*n_kv-2 launches per (layer, d) pair on a dedicated stream. + if (walked_sibling) { + const int base = last_tree_base_pos_; + const int n_full_attn = (int)cache_.attn_k.size(); + for (int d = 1; d < commit_n; d++) { + const int src_dfs = accepted_dfs[d]; + const int dst_slot = d; + if (src_dfs == dst_slot) continue; + for (int l = 0; l < n_full_attn; l++) { + ggml_tensor * ck = cache_.attn_k[l]; + ggml_tensor * cv = cache_.attn_v[l]; + if (!ck || !cv) continue; + const size_t slot_bytes = ck->nb[1]; + const int n_kv = (int)ck->ne[2]; + const size_t pitch = ck->nb[2]; + const size_t src_off = (size_t)(base + src_dfs) * slot_bytes; + const size_t dst_off = (size_t)(base + dst_slot) * slot_bytes; + cudaMemcpy2DAsync((char *)ck->data + dst_off, pitch, + (const char *)ck->data + src_off, pitch, + slot_bytes, n_kv, + cudaMemcpyDeviceToDevice, stream); + cudaMemcpy2DAsync((char *)cv->data + dst_off, cv->nb[2], + (const char *)cv->data + src_off, cv->nb[2], + slot_bytes, (int)cv->ne[2], + cudaMemcpyDeviceToDevice, stream); + } + } + } + + // Sync rollback stream so the next graph_compute (on default stream) + // sees a consistent KV/SSM state. + if (rollback_stream_) cudaStreamSynchronize(rollback_stream_); + + // Advance cur_pos to "just past the last committed slot" so the next + // verify_batch's kv_start lines up. root = dfs 0 lives at base, so + // commit_n committed tokens occupy slots [base..base+commit_n-1]. + cache_.cur_pos = last_tree_base_pos_ + commit_n; + return true; +} + +void Qwen35DFlashTarget::capture_topology_for_chain(int n_tokens, int base_pos) { + if (n_tokens <= 0) { last_tree_base_pos_ = -1; return; } + last_tree_base_pos_ = base_pos; + last_tree_n_nodes_ = n_tokens - 1; + last_tree_parents_.resize(n_tokens); + last_tree_depths_.resize((size_t)(n_tokens - 1)); + last_tree_parents_[0] = -1; + for (int i = 1; i < n_tokens; i++) { + last_tree_parents_[i] = i - 1; + last_tree_depths_[i - 1] = i; + } +} + +bool Qwen35DFlashTarget::restore_kv_at_chain(int accept_n) { + // A chain of N tokens recorded by verify_batch is the DFS spine + // [0, 1, ..., N-1]. Roll back to slot accept_n: the first (accept_n + 1) + // positions remain committed, the tail is discarded. Returns false if + // the most recent verify_batch did NOT capture per-position intermediates + // (chain_capture_enabled_ was off, or n_tokens overflowed the cache) — + // the chain runner falls back to its legacy snapshot+recommit path. + if (accept_n < 0) return false; + if (last_tree_base_pos_ < 0) return false; + if (accept_n > last_tree_n_nodes_) return false; + std::vector path((size_t)accept_n + 1); + for (int i = 0; i <= accept_n; i++) path[i] = i; + return restore_kv_at_dfs(path); +} + bool Qwen35DFlashTarget::snapshot_kv() { snapshot_ssm_state(cache_); return true; @@ -141,6 +672,32 @@ bool Qwen35DFlashTarget::project_hidden_to_tokens( return true; } +bool Qwen35DFlashTarget::project_hidden_to_logits( + const float * hidden, + int n_tokens, + std::vector & logits_out, + int & out_vocab) { + out_vocab = 0; + if (n_tokens <= 0) return false; + + if (!build_lm_head_projection_step(proj_sg_, w_, backend_, n_tokens)) { + return false; + } + + ggml_backend_tensor_set(proj_sg_.hidden_input, hidden, 0, + sizeof(float) * (size_t)n_tokens * w_.n_embd); + + auto st = ggml_backend_graph_compute(backend_, proj_sg_.gf); + if (st != GGML_STATUS_SUCCESS) return false; + + const int vocab = (int)proj_sg_.logits->ne[0]; + logits_out.resize((size_t)n_tokens * vocab); + ggml_backend_tensor_get(proj_sg_.logits, logits_out.data(), 0, + sizeof(float) * (size_t)n_tokens * vocab); + out_vocab = vocab; + return true; +} + int Qwen35DFlashTarget::mask_token_id() const { return w_.mask_token_id; } diff --git a/dflash/src/qwen35/qwen35_dflash_target.h b/dflash/src/qwen35/qwen35_dflash_target.h index 6a72e48b..515277f4 100644 --- a/dflash/src/qwen35/qwen35_dflash_target.h +++ b/dflash/src/qwen35/qwen35_dflash_target.h @@ -7,12 +7,14 @@ #pragma once #include "common/dflash_target.h" +#include "common/ddtree.h" #include "internal.h" // TargetWeights, TargetCache, DraftWeights #include "step_graph.h" #include "graph_builders.h" #include "ggml.h" #include "ggml-backend.h" +#include "device_runtime.h" // cudaStream_t #include @@ -30,6 +32,63 @@ class Qwen35DFlashTarget : public DFlashTarget { ~Qwen35DFlashTarget() override; + ggml_backend_t backend() const override { return backend_; } + + // Phase B+ fused-LM-head path: the Qwen3.6 MTP head's step graph can + // append `mul_mat(w_.output, x_normed) -> argmax` directly so it + // avoids a hidden -> host -> separate-cgraph round trip per call. + ggml_tensor * lm_head_weight() const override { return w_.output; } + + // Mirror the causal window onto MTP head's flash-attn so it sees the + // same active context as the target's full-attention blocks. + int fa_window() const override { return fa_window_; } + + // Enable per-position post-norm hidden capture during verify_batch. + // Off by default; MTP modules that depend on hidden_at_pos() flip it on + // in attach() — and ALSO pin capture_mode_ to FULL_SEQ so the runtime + // toggle below cannot demote it to LAST_ROW_ONLY (which captures the + // wrong row for partial-accept iterations of the MTP chain and silently + // returns null from hidden_at_pos_pre_norm). + void enable_hidden_seq_capture(bool on) override { + capture_hidden_seq_ = on; + if (on) { + capture_mode_ = VerifyCaptureMode::FULL_SEQ; + capture_pinned_ = true; + } else { + capture_pinned_ = false; + } + } + + // Hidden-sequence capture granularity. FULL_SEQ downloads the entire + // [n_tokens, n_embd] post-norm + pre-norm hidden tensors device->host + // on every verify_batch (needed by warm_head_kv during prefill, which + // consumes per-position hiddens via last_hidden_seq()). LAST_ROW_ONLY + // downloads only row n_tokens-1 — sufficient for decode-side chain + // verifies whose only consumer is hidden_at_pos(base_pos - 1), i.e. the + // last token of the just-verified batch. Switching to LAST_ROW_ONLY + // after prefill collapses two ~80 KB device->host syncs (post-norm + + // pre-norm, hidden_dim=5120, D+1=4 tokens) into two ~20 KB single-row + // syncs and saves the 2x WSL2 scheduler latency hit per verify. + enum class VerifyCaptureMode { + FULL_SEQ, // default — required during prefill / warm_head_kv + LAST_ROW_ONLY, // decode mode — only hidden_at_pos(base_pos-1) used + }; + // Runtime toggle is a NO-OP once MTP has pinned capture to FULL_SEQ. This + // makes the partial-accept-returns-null bug impossible by construction — + // non-MTP callers can still use this freely; MTP callers cannot demote it. + void set_hidden_capture_mode(VerifyCaptureMode mode) { + if (capture_pinned_) return; + capture_mode_ = mode; + } + VerifyCaptureMode hidden_capture_mode() const { return capture_mode_; } + + void set_hidden_capture_scope(DFlashTarget::VerifyCaptureScope scope) override { + if (capture_pinned_) return; + capture_mode_ = (scope == DFlashTarget::VerifyCaptureScope::LAST_ROW_ONLY) + ? VerifyCaptureMode::LAST_ROW_ONLY + : VerifyCaptureMode::FULL_SEQ; + } + // ── DFlashTarget interface ────────────────────────────────────── bool verify_batch(const std::vector & tokens, @@ -37,9 +96,42 @@ class Qwen35DFlashTarget : public DFlashTarget { int & last_tok, std::vector * all_argmax = nullptr) override; + // Tree-verify override. Stage 1 stub: only handles the degenerate + // single-token case (tree.n_nodes == 0) by dispatching to verify_batch. + // For real DDTree shapes (n_nodes > 0) it returns false so the harness + // falls back to chain-verify. Stage 2 will wire build_target_step_tree + // + ancestor mask here. + bool verify_tree(const std::vector & flat_tokens, + const DDTree & tree, + int base_pos, + std::vector & out_argmax, + std::vector * out_logits = nullptr) override; + bool snapshot_kv() override; bool restore_kv() override; + // Rollback DeltaNet SSM/conv + full-attn KV to the accepted-path tail of + // the most recent verify_tree() call. Dequantizes per-DFS-slot SSM + // snapshots; compacts full-attn KV when accepted walk deviates from DFS spine. + // Postcondition: cache_.cur_pos = root_base_pos + commit_n. + // Returns false if verify_tree was not called or accepted_dfs is out of range. + bool restore_kv_at_dfs(const std::vector & accepted_dfs) override; + + // Chain-mode rollback: builds spine [0..accept_n] and dispatches to + // restore_kv_at_dfs. Requires chain capture enabled on the prior verify_batch. + // Returns false if capture was off; caller falls back to snapshot+recommit. + bool restore_kv_at_chain(int accept_n) override; + + // Enable per-position DeltaNet intermediate capture in verify_batch. + // Off by default; unsafe when n_tokens > max_verify_tokens. + // When on, verify_batch populates ssm_intermediate/conv_input_cache and + // records a linear-chain topology for restore_kv_at_chain. + void enable_chain_capture(bool on) override { chain_capture_enabled_ = on; } + + // Record linear-chain topology (spine [0..n_tokens-1] at base_pos) so + // restore_kv_at_chain() can locate the rollback slot. + void capture_topology_for_chain(int n_tokens, int base_pos) override; + bool is_eos(int token) const override; bool embed_tokens(const int32_t * tokens, int n, @@ -49,10 +141,55 @@ class Qwen35DFlashTarget : public DFlashTarget { int n_tokens, std::vector & tokens_out) override; + bool project_hidden_to_logits(const float * hidden, + int n_tokens, + std::vector & logits_out, + int & out_vocab) override; + int hidden_size() const override { return w_.n_embd; } int mask_token_id() const override; const std::vector & capture_layer_ids() const override; + // Return the backbone's final post-norm hidden state for the last committed + // token (n_embd floats, F32). Populated by the most recent verify_batch. + // Returns nullptr if verify_batch has not been called yet. + const float * last_hidden() const override { return last_hidden_cpu_.empty() ? nullptr : last_hidden_cpu_.data(); } + + // Full post-norm hidden sequence from the last verify_batch, laid out as + // [token_0_hidden, ..., token_{n_tokens-1}_hidden], n_tokens * n_embd floats. + const float * last_hidden_seq(int * out_n_tokens) const override { + if (out_n_tokens) *out_n_tokens = last_hidden_seq_n_; + return last_hidden_seq_cpu_.empty() ? nullptr : last_hidden_seq_cpu_.data(); + } + + // Absolute-position accessor: returns &last_hidden_seq[(pos - chunk_start) * n_embd] + // if pos is within [last_verify_chunk_start_, last_verify_chunk_start_ + n_seq). + const float * hidden_at_pos(int abs_pos) const override { + const int rel = abs_pos - last_verify_chunk_start_; + if (rel < 0 || rel >= last_hidden_seq_n_ || + last_hidden_seq_cpu_.empty()) { + return nullptr; + } + return last_hidden_seq_cpu_.data() + (size_t)rel * w_.n_embd; + } + + // Pre-final-output-norm variant of hidden_at_pos. Mirrors llama.cpp + // PR #22673's `t_h_pre_norm`: the Qwen3.6 MTP head's hnorm normalises + // h_prev internally, so feeding it the post-output-norm tensor double- + // normalises and compounds per-depth rejection. The chain GPU loop's + // outer h_prev_0 seed must use this accessor; the intra-iter re-feed + // already uses the MTP graph's pre-norm output (state_->last_hidden). + // Returns nullptr if the most recent verify_batch did NOT capture the + // pre-norm sequence (enable_hidden_seq_capture(true) flips this on). + const float * hidden_at_pos_pre_norm(int abs_pos) const override { + const int rel = abs_pos - last_verify_chunk_start_; + if (rel < 0 || rel >= last_hidden_seq_n_ || + last_hidden_seq_pre_norm_cpu_.empty()) { + return nullptr; + } + return last_hidden_seq_pre_norm_cpu_.data() + (size_t)rel * w_.n_embd; + } + private: TargetWeights & w_; TargetCache & cache_; @@ -66,6 +203,76 @@ class Qwen35DFlashTarget : public DFlashTarget { // LM-head projection graph (lazily built). StepGraph proj_sg_; + + // CPU-side copy of the last token's post-norm hidden state. + // Filled by verify_batch after each GPU graph_compute. + mutable std::vector last_hidden_cpu_; + + // CPU-side copy of the full [n_tokens, n_embd] post-norm hidden sequence + // from the last verify_batch. Used by Qwen3.6 MTP warm_head_kv to seed + // the head's per-position K/V cache during prefill. + mutable std::vector last_hidden_seq_cpu_; + // CPU-side copy of the full [n_tokens, n_embd] PRE-final-output-norm + // hidden sequence — populated alongside last_hidden_seq_cpu_ when + // capture_hidden_seq_ is on. Consumed by hidden_at_pos_pre_norm() so + // the Qwen3.6 MTP chain's outer h_prev_0 seed dodges double-normalisation. + mutable std::vector last_hidden_seq_pre_norm_cpu_; + mutable int last_hidden_seq_n_ = 0; + // Absolute position of the FIRST token captured in last_hidden_seq_cpu_. + // Used by hidden_at_pos to translate an absolute sequence position to an + // index inside the captured chunk. -1 means no captured chunk yet. + mutable int last_verify_chunk_start_ = -1; + + // Whether verify_batch should request the full post-norm hidden sequence + // and copy it to the host. Toggled by enable_hidden_seq_capture(). + bool capture_hidden_seq_ = false; + + // Granularity of the device->host download when capture_hidden_seq_ is on. + // Toggled by set_hidden_capture_mode(). See VerifyCaptureMode docs. + VerifyCaptureMode capture_mode_ = VerifyCaptureMode::FULL_SEQ; + + // Pinned by enable_hidden_seq_capture(true) — once an MTP module attaches + // and enables capture, the mode cannot be demoted to LAST_ROW_ONLY (that + // captures the wrong row for partial-accept chain iters and silently + // returns null from hidden_at_pos_pre_norm). + bool capture_pinned_ = false; + +#ifdef DFLASH_VERIFY_PROFILE + // Per-instance accumulators: summed wall-clock (ms) per verify_batch call; + // dumped from destructor. Zero-cost when flag is off. + mutable double vprof_sum_set_ = 0.0; + mutable double vprof_sum_compute_ = 0.0; + mutable double vprof_sum_get_hidden_ = 0.0; + mutable double vprof_sum_get_hpre_ = 0.0; + mutable double vprof_sum_get_argmax_ = 0.0; + mutable double vprof_sum_total_ = 0.0; + mutable long long vprof_n_calls_ = 0; +#endif // DFLASH_VERIFY_PROFILE + + // ── Stage 3: state captured by verify_tree for restore_kv_at_dfs() ── + // base_pos of the most recent verify_tree call (= root slot in KV). + int last_tree_base_pos_ = -1; + // n_nodes of the most recent verify_tree. + int last_tree_n_nodes_ = 0; + // Copies of tree.parents/depths from the most recent verify_tree so + // rollback's conv-window walk (which traverses ancestry, not DFS) is + // self-contained. + std::vector last_tree_parents_; + std::vector last_tree_depths_; + + // When true, verify_batch will populate per-position DeltaNet + // ssm_intermediate / conv_input_cache buffers AND record a linear-chain + // topology in last_tree_* so restore_kv_at_chain() can dispatch to + // restore_kv_at_dfs() on partial-accept. Toggled by the MTP chain runner + // via enable_chain_capture(); off by default (capture is unsafe on + // n_tokens > max_verify_tokens, e.g. 512-token prefill chunks, where the + // in-graph ggml_view_3d into the conv_input cache asserts). + bool chain_capture_enabled_ = false; + + // Dedicated CUDA stream for restore_kv_at_dfs copies (bug #3): avoids + // serializing ~384 per-head launches on the default stream. Created on + // first use, destroyed in the dtor. + mutable cudaStream_t rollback_stream_ = nullptr; }; } // namespace dflash::common diff --git a/dflash/src/qwen35/qwen35_mtp.cpp b/dflash/src/qwen35/qwen35_mtp.cpp new file mode 100644 index 00000000..10df5ce1 --- /dev/null +++ b/dflash/src/qwen35/qwen35_mtp.cpp @@ -0,0 +1,2099 @@ +// qwen35_mtp.cpp — see qwen35_mtp.h for contract. +// +// Implements the full DeepSeek-V3 NextN per-head forward +// per Eq 21-23: RMSNorm → eh_proj([hnorm(h_prev); enorm(embed)]) → head-owned +// TRMBlock (Q/K/V w/ QK-norm + RoPE + GQA → attn_output residual → SwiGLU FFN +// residual) → shared_head_norm → shared LM head projection. +// +// Implementation: Path A — CPU host floats. All per-head tensors are +// dequantized to float once via tensor_to_floats() and matvec / RMSNorm / +// SiLU / RoPE / softmax are hand-rolled on the host. Per-step cost is a few +// hundred ms for n_embd=5120, n_ffn=17408 (acceptable for proof-of-correctness; +// GPU migration is a follow-up PR). +// +// The MTP head's attn_q tensor is packed Q+gate (same convention as backbone's +// full-attention blocks): first q_dim elements = Q, last q_dim = gate. The gate +// is passed through sigmoid and multiplied into the attention output before the +// attn_output projection. This matches the backbone forward at blk.63. +// +// RoPE: standard rotary at rope_dimension_count=64 out of head_dim=256, using +// rope_theta=1e7 (qwen35.rope.freq_base from GGUF). For γ_max=1 the draft +// position is base_pos + 0 = base_pos. +// +// Fallback mode: define MTP_PHASE_A_FALLBACK to bypass the TRMBlock and +// use only eh_proj+shared_head_norm+lm_head (useful if a smaller/synthetic GGUF +// lacks the transformer-block tensors). The default path requires all attn/ffn +// tensors to be non-null. + +#include "qwen35_mtp.h" +#include "qwen35_mtp_graph.h" + +#include "common/dflash_target.h" +#include "common/gguf_mmap.h" +#include "qwen35/qwen35_dflash_target.h" + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "gguf.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace dflash::common::mtp { + +#ifdef DFLASH_MTP_PROFILE +// Per-iter profiler for the step_chain_gpu_ loop. Enabled by -DDFLASH_MTP_PROFILE=1. +// Uses host-side wall-clock (same latency as cudaEvents because every ggml +// backend call internally calls cudaStreamSynchronize). +namespace { +inline bool mtp_profile_enabled() { + static const bool on = (std::getenv("DFLASH_MTP_PROFILE") != nullptr); + return on; +} + +using prof_clock = std::chrono::steady_clock; +inline double prof_ms_since(prof_clock::time_point t0) { + auto t1 = prof_clock::now(); + return std::chrono::duration(t1 - t0).count(); +} +} // namespace +#endif // DFLASH_MTP_PROFILE + + +// ── Internal helpers ────────────────────────────────────────────────────── + +namespace { + +// RMSNorm: out[i] = x[i] / rms(x) * weight[i] +// All operations in-place on a separate output buffer. +void rmsnorm_cpu(const float * x, + const float * weight, + float * out, + int n, + float eps = 1e-6f) { + float ss = 0.0f; + for (int i = 0; i < n; i++) ss += x[i] * x[i]; + const float rms_inv = 1.0f / std::sqrt(ss / n + eps); + for (int i = 0; i < n; i++) { + // ggml weight tensors store as F32 rows; safe to cast for small dims + out[i] = x[i] * rms_inv * weight[i]; + } +} + +// Read a ggml_tensor's data as a flat float array. Returns false if the +// tensor is null or its type is not GGML_TYPE_F32. +bool tensor_to_floats(const ggml_tensor * t, + std::vector & out) { + if (!t) return false; + const size_t n = ggml_nelements(t); + out.resize(n); + + // Stage the tensor's raw bytes to a host buffer. Tensors backed by a + // backend buffer (CPU or CUDA) require ggml_backend_tensor_get to copy + // host-side; bare tensors created by tests (no buffer, raw host pointer + // assigned to t->data) can be read directly. + const size_t total_bytes = ggml_nbytes(t); + std::vector staging(total_bytes); + const uint8_t * src_bytes = nullptr; + if (t->buffer) { + ggml_backend_tensor_get(const_cast(t), + staging.data(), 0, total_bytes); + src_bytes = staging.data(); + } else if (t->data) { + src_bytes = static_cast(t->data); + } else { + return false; + } + + if (t->type == GGML_TYPE_F32) { + std::memcpy(out.data(), src_bytes, n * sizeof(float)); + return true; + } + + const ggml_type_traits * tr = ggml_get_type_traits(t->type); + if (!tr || !tr->to_float) return false; + const int64_t row_len = t->ne[0]; + if (row_len <= 0 || n % (size_t)row_len != 0) return false; + const int64_t n_rows = (int64_t)n / row_len; + const size_t row_bytes = ggml_row_size(t->type, row_len); + for (int64_t r = 0; r < n_rows; r++) { + tr->to_float(src_bytes + (size_t)r * row_bytes, + out.data() + (size_t)r * row_len, + row_len); + } + return true; +} + +// Matrix-vector multiply: y = A @ x +// A is [rows x cols] stored row-major (rows = out_dim, cols = in_dim). +void matvec_cpu(const float * A, + const float * x, + float * y, + int rows, + int cols) { + for (int r = 0; r < rows; r++) { + float acc = 0.0f; + const float * row = A + (size_t)r * cols; + for (int c = 0; c < cols; c++) acc += row[c] * x[c]; + y[r] = acc; + } +} + +// Argmax over a float vector; returns index of max element. +int32_t argmax(const float * logits, int n) { + int32_t best = 0; + for (int i = 1; i < n; i++) { + if (logits[i] > logits[best]) best = i; + } + return best; +} + +// Fill StepOutput.topk_logprobs / topk_ids with the K highest log-softmax +// entries from `logits` (length n_vocab), sorted DESCENDING by logprob. +// Uses partial_sort over (logprob, id) pairs — O(n_vocab + K log K) is +// trivial vs the per-head TRMBlock forward and avoids a full sort. +void emit_topk_logprobs(const float * logits, int n_vocab, int K, + mtp::StepOutput & out) { + K = std::min(K, n_vocab); + if (K <= 0) return; + + // log-softmax: stable via max-shift + logsumexp. + float max_l = logits[0]; + for (int i = 1; i < n_vocab; i++) if (logits[i] > max_l) max_l = logits[i]; + double denom = 0.0; + for (int i = 0; i < n_vocab; i++) denom += std::exp((double)(logits[i] - max_l)); + const float log_denom = (float)std::log(denom) + max_l; + + // Pair (logprob, id); partial_sort top-K descending. + std::vector> scratch; + scratch.reserve(n_vocab); + for (int i = 0; i < n_vocab; i++) { + scratch.emplace_back(logits[i] - log_denom, (int32_t)i); + } + std::partial_sort(scratch.begin(), scratch.begin() + K, scratch.end(), + [](const auto & a, const auto & b) { + return a.first > b.first; + }); + + out.topk_logprobs.resize(K); + out.topk_ids.resize(K); + for (int i = 0; i < K; i++) { + out.topk_logprobs[i] = scratch[i].first; + out.topk_ids[i] = scratch[i].second; + } +} + +// Per-head RMSNorm: apply rmsnorm_cpu to each n_per_head-element slice of `x` +// (total dim = n_heads_total * n_per_head) using corresponding weight slice. +// Weight tensor `w` has shape [n_per_head] (same weight shared across all heads +// when called for Q-norm / K-norm). +void per_head_rmsnorm(float * x, + const float * w, + int n_heads_total, + int n_per_head, + float eps = 1e-6f) { + for (int h = 0; h < n_heads_total; h++) { + float * slice = x + (size_t)h * n_per_head; + float ss = 0.0f; + for (int i = 0; i < n_per_head; i++) ss += slice[i] * slice[i]; + const float inv = 1.0f / std::sqrt(ss / n_per_head + eps); + for (int i = 0; i < n_per_head; i++) { + slice[i] = slice[i] * inv * w[i]; + } + } +} + +// SiLU: silu(x) = x * sigmoid(x) = x / (1 + exp(-x)) +inline float silu(float x) { + return x / (1.0f + std::exp(-x)); +} + +// Apply standard rotary position embedding (RoPE) in-place to a flat +// [n_heads * head_dim] buffer. Only the first `n_rot` elements of each head +// are rotated (n_rot <= head_dim). The remaining elements pass through. +// freq_base: rope theta (e.g. 1e7). position: absolute sequence position. +void rope_cpu(float * x, + int n_heads, + int head_dim, + int n_rot, // number of dims to rotate (e.g. 64) + int position, + float freq_base) { + // n_rot must be even (pairs of dims rotated together). + const int half = n_rot / 2; + for (int h = 0; h < n_heads; h++) { + float * head = x + (size_t)h * head_dim; + for (int i = 0; i < half; i++) { + const float theta = (float)position / + std::pow(freq_base, (float)(2 * i) / (float)n_rot); + const float cos_t = std::cos(theta); + const float sin_t = std::sin(theta); + const float x0 = head[i]; + const float x1 = head[i + half]; + head[i] = x0 * cos_t - x1 * sin_t; + head[i + half] = x0 * sin_t + x1 * cos_t; + } + } +} + +// Multi-slot scaled dot-product attention with causal masking and GQA. +// +// q : [n_head * head_dim] — queries at the current draft position +// k_cache : [n_slots * n_head_kv * head_dim] — K cache, slot-major +// v_cache : [n_slots * n_head_kv * head_dim] — V cache, slot-major +// out : [n_head * head_dim] — attention output (GQA-expanded) +// +// Attends over slots [0, n_slots). No explicit causal mask is needed because +// the caller passes only the slots representing positions <= the current +// draft position. Score = (Q · K) / sqrt(head_dim). +void range_attention(const float * q, + const float * k_cache, + const float * v_cache, + float * out, + int n_head, + int n_head_kv, + int head_dim, + int n_slots) { + if (n_slots <= 0) { + std::memset(out, 0, sizeof(float) * (size_t)n_head * head_dim); + return; + } + const int group = n_head / n_head_kv; + const float scale = 1.0f / std::sqrt((float)head_dim); + std::vector scores(n_slots); + for (int qh = 0; qh < n_head; qh++) { + const int kvh = qh / group; + const float * qhead = q + (size_t)qh * head_dim; + for (int s = 0; s < n_slots; s++) { + const float * khead = + k_cache + ((size_t)s * n_head_kv + kvh) * head_dim; + float acc = 0.0f; + for (int i = 0; i < head_dim; i++) acc += qhead[i] * khead[i]; + scores[s] = acc * scale; + } + // Stable softmax. + float max_s = scores[0]; + for (int s = 1; s < n_slots; s++) if (scores[s] > max_s) max_s = scores[s]; + float denom = 0.0f; + for (int s = 0; s < n_slots; s++) { + scores[s] = std::exp(scores[s] - max_s); + denom += scores[s]; + } + const float inv_denom = (denom > 0.0f) ? (1.0f / denom) : 0.0f; + float * ohead = out + (size_t)qh * head_dim; + std::memset(ohead, 0, sizeof(float) * head_dim); + for (int s = 0; s < n_slots; s++) { + const float w = scores[s] * inv_denom; + const float * vhead = + v_cache + ((size_t)s * n_head_kv + kvh) * head_dim; + for (int i = 0; i < head_dim; i++) ohead[i] += w * vhead[i]; + } + } +} + +bool append_tensor(std::vector & tensors, + ggml_tensor * t) { + if (!t || t->data) return true; + if (std::find(tensors.begin(), tensors.end(), t) == tensors.end()) { + tensors.push_back(t); + } + return true; +} + +bool materialize_mtp_tensors(const std::string & gguf_path, + const Qwen35MtpWeights & weights, + ggml_backend_buffer_type_t target_buft, + ggml_backend_buffer_t & out_buf, + std::string & out_error) { + std::vector tensors; + for (const auto & h : weights.heads) { + // NextN-specific tensors + append_tensor(tensors, h.eh_proj); + append_tensor(tensors, h.enorm); + append_tensor(tensors, h.hnorm); + append_tensor(tensors, h.shared_head_norm); + // shared_head_head can be vocab-sized; leave it unmaterialized so + // production uses the target's GPU lm_head projection fallback. + + // Shape B: head-owned transformer-block tensors (all 11 required). + append_tensor(tensors, h.attn_norm); + append_tensor(tensors, h.attn_q); + append_tensor(tensors, h.attn_q_norm); + append_tensor(tensors, h.attn_k); + append_tensor(tensors, h.attn_k_norm); + append_tensor(tensors, h.attn_v); + append_tensor(tensors, h.attn_output); + append_tensor(tensors, h.post_attention_norm); + append_tensor(tensors, h.ffn_gate); + append_tensor(tensors, h.ffn_up); + append_tensor(tensors, h.ffn_down); + } + if (tensors.empty()) return true; + + ggml_backend_buffer_type_t buft = target_buft + ? target_buft : ggml_backend_cpu_buffer_type(); + const size_t alignment = ggml_backend_buft_get_alignment(buft); + size_t total = 0; + std::vector offsets; + offsets.reserve(tensors.size()); + for (ggml_tensor * t : tensors) { + const size_t r = total % alignment; + if (r != 0) total += alignment - r; + offsets.push_back(total); + total += ggml_backend_buft_get_alloc_size(buft, t); + } + + ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, total); + if (!buf) { + out_error = "qwen35_mtp: failed to allocate CPU buffer for MTP tensors"; + return false; + } + ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + + char * base = static_cast(ggml_backend_buffer_get_base(buf)); + for (size_t i = 0; i < tensors.size(); i++) { + if (ggml_backend_tensor_alloc(buf, tensors[i], base + offsets[i]) != GGML_STATUS_SUCCESS) { + ggml_backend_buffer_free(buf); + out_error = "qwen35_mtp: ggml_backend_tensor_alloc failed"; + return false; + } + } + + gguf_init_params gp{}; + gp.no_alloc = true; + gp.ctx = nullptr; + gguf_context * gguf = gguf_init_from_file(gguf_path.c_str(), gp); + if (!gguf) { + ggml_backend_buffer_free(buf); + out_error = "qwen35_mtp: gguf_init_from_file failed for " + gguf_path; + return false; + } + + dflash::common::GgufMmap mm; + std::string mmap_error; + if (!mm.open(gguf_path, mmap_error)) { + gguf_free(gguf); + ggml_backend_buffer_free(buf); + out_error = mmap_error; + return false; + } + + const size_t data_start = gguf_get_data_offset(gguf); + for (ggml_tensor * t : tensors) { + const char * name = ggml_get_name(t); + const int64_t tid = gguf_find_tensor(gguf, name); + if (tid < 0) { + gguf_free(gguf); + ggml_backend_buffer_free(buf); + out_error = std::string("qwen35_mtp: tensor missing from GGUF: ") + name; + return false; + } + const size_t off = data_start + gguf_get_tensor_offset(gguf, tid); + const size_t sz = gguf_get_tensor_size(gguf, tid); + if (off + sz > mm.size()) { + gguf_free(gguf); + ggml_backend_buffer_free(buf); + out_error = std::string("qwen35_mtp: tensor overflows GGUF: ") + name; + return false; + } + ggml_backend_tensor_set(t, + static_cast(mm.data()) + off, + 0, sz); + } + + gguf_free(gguf); + out_buf = buf; + return true; +} + +ggml_context * load_gguf_tensor_context(const std::string & gguf_path, + std::string & out_error) { + ggml_context * ctx = nullptr; + gguf_init_params gp{}; + gp.no_alloc = true; + gp.ctx = &ctx; + gguf_context * gguf = gguf_init_from_file(gguf_path.c_str(), gp); + if (!gguf || !ctx) { + if (gguf) gguf_free(gguf); + out_error = "qwen35_mtp: failed to create GGUF tensor context for " + gguf_path; + return nullptr; + } + gguf_free(gguf); + return ctx; +} + +} // anonymous namespace + +// ── State ───────────────────────────────────────────────────────────────── + +// Per-head KV cache buffer. +// GPU mode: ggml_tensors on the backbone backend with layout +// [head_dim, n_ctx, n_head_kv] (matches backbone cache_k / cache_v +// so flash_attn views slice cleanly along dim 1). +// CPU mode (T1 stubs / no-backend fallback): fp32 host vectors with +// layout [n_slot * n_head_kv * head_dim] (slot-major). +struct HeadKvBuffer { + ggml_tensor * k_cache = nullptr; + ggml_tensor * v_cache = nullptr; + std::vector k; + std::vector v; +}; + +struct Qwen35MtpModule::State { + Qwen35MtpWeights weights; + DFlashTarget * target = nullptr; + ggml_context * owned_ctx = nullptr; + ggml_backend_buffer_t mtp_buf = nullptr; + bool loaded = false; + bool attached = false; + + // last_hidden is zeroed on init / reset_chain; carries the backbone's + // post-norm hidden from the last committed step once set_initial_hidden is called. + std::vector last_hidden; // length == weights.n_embd when loaded + + // set_initial_hidden state: stash pointer + dim from the backbone caller. + // Consumed by the Shape B TRMBlock forward. + // Pointer is NOT owned; it must remain valid for the duration of step_batch. + const float * initial_hidden_ptr = nullptr; + int initial_hidden_dim = 0; + + // Per-head KV cache buffers, sized n_ctx slots × n_head_kv × key/val_length. + // GPU path: backed by ggml tensors on backbone backend (kv_ctx + kv_buf). + // warm_head_kv() fills slots [1, n_prompt] post-prefill; each step_batch() + // call writes its draft K/V at slot (base_pos + h) inside the cgraph. + std::vector head_kv; + int n_ctx = 0; // allocated slots per head + + // GPU-side head_kv tensor lifetimes. + ggml_context * kv_ctx = nullptr; + ggml_backend_buffer_t kv_buf = nullptr; + + // Bug #5 fix: graphs are shape-only. Slot write + FA read slots/mask + // are runtime inputs, so a single graph services every draft_pos for + // a given (head_idx, fa_window, fused_lm_head, topk_k). Cap at 4 + // entries — head_kv is single-head in production and target config + // is stable, so the LRU collapses to 1 entry in practice. + struct StepGraphKey { + int head_idx = -1; + int fa_window = 0; + bool fused_lm_head = false; + int topk_k = 0; + }; + std::array>, 4> step_sg_cache{}; + // Single scratch graph for the deprecated path (callers that pass the + // legacy build_qwen35_mtp_step_graph signature with no caching). + Qwen35MtpStepGraph step_sg; + // Cached warmup graph (rebuilt per call because n_tokens varies). + Qwen35MtpWarmGraph warm_sg; + + // Per-head top-K logprob emission. K=1 means argmax-only (legacy ABI: + // StepOutput.topk_* stays empty). K>1 populates the topk surface in + // every emitted StepOutput. Configured once via set_draft_topk(); + // persists across reset_chain() since the bench/harness toggles K at + // setup time, not per chain. + int draft_topk = 1; +}; + +// ── Lifecycle ───────────────────────────────────────────────────────────── + +Qwen35MtpModule::Qwen35MtpModule() : state_(std::make_unique()) {} +Qwen35MtpModule::~Qwen35MtpModule() { shutdown(); } + +bool Qwen35MtpModule::init(const std::string & gguf_path, + DFlashTarget * target, + std::string & out_error) { + shutdown(); + ggml_context * ctx = load_gguf_tensor_context(gguf_path, out_error); + if (!ctx) { + return false; + } + state_->owned_ctx = ctx; + if (!init(gguf_path, ctx, target, out_error)) { + shutdown(); + return false; + } + return true; +} + +bool Qwen35MtpModule::init(const std::string & gguf_path, + ggml_context * ctx, + DFlashTarget * target, + std::string & out_error) { + if (!target) { + out_error = "Qwen35MtpModule::init: target is null"; + return false; + } + if (!ctx) { + out_error = "Qwen35MtpModule::init: ctx is null"; + return false; + } + if (!state_->owned_ctx) { + shutdown(); + } + const bool ok = load_qwen35_mtp_weights( + gguf_path, ctx, + /*expected_n_embd=*/target->hidden_size(), + /*expected_n_vocab=*/0, + state_->weights, out_error); + + if (!ok) return false; + ggml_backend_t tgt_backend = target ? target->backend() : nullptr; + ggml_backend_buffer_type_t tgt_buft = tgt_backend + ? ggml_backend_get_default_buffer_type(tgt_backend) + : ggml_backend_cpu_buffer_type(); + if (!materialize_mtp_tensors(gguf_path, state_->weights, tgt_buft, + state_->mtp_buf, out_error)) { + state_->weights = {}; + return false; + } + + // Hard contract — current GPU paths (warm_head_kv, step_batch_gpu_, + // step_chain) hardcode h=0 and chain only against the first NextN head. + // A GGUF with n_heads>1 (e.g. DeepSeek-V3 NextN can ship up to 3) would + // silently produce wrong drafts: head_kv is allocated for all heads but + // only head 0 is warmed/written. Fail loudly until the multi-head path + // is built and tested. Per momus review, "the one thing nobody checked". + if (state_->weights.n_heads > 1) { + char buf[256]; + std::snprintf(buf, sizeof(buf), + "Qwen35MtpModule::init: n_heads=%d (>1) is not supported. " + "warm_head_kv only initializes head 0 and step_chain reuses " + "head 0 across iters; multi-head GGUFs would silently produce " + "wrong drafts. See momus review.", + state_->weights.n_heads); + out_error = buf; + state_->weights = {}; + return false; + } + + // Per-head KV cache allocation. Two modes: + // - GPU mode (default when target has a backend): allocate ggml tensors + // [head_dim, n_ctx, n_head_kv] on backbone backend so the step cgraph + // can write/read KV in-place via ggml_cpy and ggml_view_3d. + // - CPU mode (no backend, e.g. T1 tests): fp32 host vectors only. + { + const int gamma_max = state_->weights.n_heads; + const int n_head_kv = state_->weights.n_head_kv; + const int key_len = state_->weights.n_key_length; + const int val_len = state_->weights.n_value_length; + // Chain horizon = max prompt+decode positions the MTP head KV can hold. + // Was hardcoded to 8192; overflows on real agentic prompts (Claude Code + // sends 9-24K). Allow override via DFLASH27B_MTP_CTX so the daemon path + // can size this to match target max_ctx. + int n_ctx = 8192; + if (const char * s = std::getenv("DFLASH27B_MTP_CTX")) { + const int v = std::atoi(s); + if (v > 0) n_ctx = v; + } + state_->n_ctx = n_ctx; + + if (n_head_kv > 0 && key_len > 0 && val_len > 0 && gamma_max > 0) { + state_->head_kv.resize(gamma_max); + // The CPU forward path (T1 stub tests, no backend) reads/writes + // the host vectors; the GPU path reads/writes the backend tensors. + // Allocate only the side we will actually use. + if (!tgt_backend) { + for (int h = 0; h < gamma_max; h++) { + state_->head_kv[h].k.assign( + (size_t)n_ctx * n_head_kv * key_len, 0.0f); + state_->head_kv[h].v.assign( + (size_t)n_ctx * n_head_kv * val_len, 0.0f); + } + } else { + const int rb_tensors = 2 * gamma_max; + ggml_init_params kp{}; + kp.mem_size = (size_t)(rb_tensors + 16) * ggml_tensor_overhead(); + kp.mem_buffer = nullptr; + kp.no_alloc = true; + state_->kv_ctx = ggml_init(kp); + if (!state_->kv_ctx) { + out_error = "qwen35_mtp: head_kv ggml_init failed"; + return false; + } + for (int h = 0; h < gamma_max; h++) { + // Head KV stored as F16 on device. CUDA + // ggml_flash_attn_ext takes F16 K/V natively (fattn.cu + // accepts F16/BF16/quant K/V; F32 K/V are auto-cast + // up-front) and ggml_cpy F32 -> F16 is supported inside + // the step graph for the per-step K/V write. Saves 50% + // of the per-head KV footprint and matches the backbone + // cache_k/cache_v dtype. + ggml_tensor * k_t = ggml_new_tensor_3d(state_->kv_ctx, + GGML_TYPE_F16, key_len, n_ctx, n_head_kv); + ggml_tensor * v_t = ggml_new_tensor_3d(state_->kv_ctx, + GGML_TYPE_F16, val_len, n_ctx, n_head_kv); + char name[64]; + std::snprintf(name, sizeof(name), "mtp_head_%d_k", h); + ggml_set_name(k_t, name); + std::snprintf(name, sizeof(name), "mtp_head_%d_v", h); + ggml_set_name(v_t, name); + state_->head_kv[h].k_cache = k_t; + state_->head_kv[h].v_cache = v_t; + } + state_->kv_buf = ggml_backend_alloc_ctx_tensors( + state_->kv_ctx, tgt_backend); + if (!state_->kv_buf) { + ggml_free(state_->kv_ctx); + state_->kv_ctx = nullptr; + state_->head_kv.clear(); + out_error = "qwen35_mtp: ggml_backend_alloc_ctx_tensors for head_kv failed"; + return false; + } + ggml_backend_buffer_clear(state_->kv_buf, 0); + } + } + } + + state_->loaded = true; + // Zero the bootstrap hidden. + state_->last_hidden.assign(state_->weights.n_embd, 0.0f); + // Clear initial_hidden state for this init. + state_->initial_hidden_ptr = nullptr; + state_->initial_hidden_dim = 0; + return attach(target); +} + +int Qwen35MtpModule::max_gamma() const { + // Post-Phase-A semantics: max_gamma is the autoregressive CHAIN depth ceiling, + // not the physical NextN head count. We re-feed the single head's own + // post-shared_head_norm hidden as h_prev to extend the chain to arbitrary depth + // (oracle blocker 5.6 analysis). Capped at 8 to match Unsloth's --spec-draft-n-max + // ceiling and keep the head_kv slot writes within n_ctx=8192. Returns 0 pre-init + // so the basic contract test (max_gamma()==0 before init) still holds. + if (!state_->loaded) return 0; + return 8; +} +int Qwen35MtpModule::hidden_size() const { return state_->weights.n_embd; } +int Qwen35MtpModule::num_heads() const { return state_->weights.n_heads; } + +bool Qwen35MtpModule::attach(DFlashTarget * target) { + if (!target) return false; + if (state_->loaded && target->hidden_size() != state_->weights.n_embd) { + std::fprintf(stderr, + "[qwen35_mtp] hidden_size mismatch (target=%d, mtp=%d)\n", + target->hidden_size(), state_->weights.n_embd); + return false; + } + state_->target = target; + state_->attached = true; + // The MTP forward needs per-position post-norm hiddens for warmup + + // chain-iter h_prev lookup; signal that to the target so it captures + // them. Non-MTP-bound targets pay nothing. + if (auto * t = dynamic_cast(target)) { + t->enable_hidden_seq_capture(true); + } + return true; +} + +void Qwen35MtpModule::reset_chain() { + if (state_->loaded) { + std::fill(state_->last_hidden.begin(), state_->last_hidden.end(), 0.0f); + } + state_->initial_hidden_ptr = nullptr; + state_->initial_hidden_dim = 0; +} + +void Qwen35MtpModule::set_draft_topk(int k) { + state_->draft_topk = (k >= 1) ? k : 1; +} + +void Qwen35MtpModule::shutdown() { + for (auto & e : state_->step_sg_cache) { + if (e.second) qwen35_mtp_step_graph_free(*e.second); + e.second.reset(); + e.first = State::StepGraphKey{}; + } + qwen35_mtp_step_graph_free(state_->step_sg); + qwen35_mtp_warm_graph_free(state_->warm_sg); + if (state_->kv_buf) { + ggml_backend_buffer_free(state_->kv_buf); + state_->kv_buf = nullptr; + } + if (state_->kv_ctx) { + ggml_free(state_->kv_ctx); + state_->kv_ctx = nullptr; + } + // Per-head KV CPU mirrors are std::vector — destructors free them. + state_->head_kv.clear(); + state_->n_ctx = 0; + + if (state_->mtp_buf) { + ggml_backend_buffer_free(state_->mtp_buf); + state_->mtp_buf = nullptr; + } + if (state_->owned_ctx) { + ggml_free(state_->owned_ctx); + state_->owned_ctx = nullptr; + } + state_->target = nullptr; + state_->attached = false; + state_->loaded = false; + state_->weights = {}; + state_->last_hidden.clear(); + state_->initial_hidden_ptr = nullptr; + state_->initial_hidden_dim = 0; +} + +// ── Shape B forward ─────────────────────────────────────────────────────── +// +// For each MTP head h ∈ [0, n_heads): +// +// Step A: h_prev = initial_hidden (k=0) or last_hidden (k>0) +// Step B: embed cur/drafted token +// Step C: Eq 21 — x = eh_proj @ [hnorm(h_prev); enorm(embed)] +// Step D: Eq 22 — TRMBlock_k (head-owned attn + FFN) applied to x +// Step E: Eq 23 — shared_head_norm + shared LM head → draft token +// +// TRMBlock_k uses head-owned tensors (NOT backbone). attn_q is packed Q+gate +// [n_embd, 2*(head_count*key_length)]; gate is sigmoid-multiplied into attn out. +// RoPE: standard (not M-RoPE) at n_rot=rope_dimension_count, theta=1e7. +// KV: single-slot, attending only to the new K/V (trivial for γ_max=1). +// +// When MTP_PHASE_A_FALLBACK is defined, skips the TRMBlock (fallback path). + +bool Qwen35MtpModule::step_batch(int32_t current_token, + int base_pos, + std::vector & out) { + // Guard: module must be loaded and attached. + if (!state_->loaded || !state_->attached) { + out.clear(); + return false; + } + + // GPU path runs whenever a CUDA backend is bound; the CPU forward below + // is reserved for T1 stub tests (attach_weights_for_test, no backend). + if (state_->kv_ctx && state_->target && state_->target->backend()) { + return step_batch_gpu_(current_token, base_pos, out); + } + + const int n_embd = state_->weights.n_embd; + const int n_vocab = state_->weights.n_vocab; + const int n_heads = state_->weights.n_heads; + const int n_head = state_->weights.n_head_count; + const int n_head_kv = state_->weights.n_head_kv; + const int key_len = state_->weights.n_key_length; + const int val_len = state_->weights.n_value_length; + const int ffn_len = state_->weights.n_ffn_length; + + // Packed Q layout: attn_q has 2*q_dim columns (Q || gate). + const int q_dim = n_head * key_len; // 24*256 = 6144 for 27B + const int kv_dim = n_head_kv * key_len; // 4*256 = 1024 + const int v_total = n_head_kv * val_len; + + // RoPE params (Qwen3.6-27B constants from GGUF metadata). + // For the 27B GGUF: rope_dimension_count=64, rope_theta=1e7. + // We use these constants directly since they're part of the verified GGUF + // contract (qwen35_mtp_redesign.md §Verified GGUF Constants). + // For text-mode MROPE, the 3 active axes share the same position, so + // it reduces to NeoX RoPE with n_rot=rope.dimension_count=64. The CPU + // fallback path uses plain rope_cpu(); the GPU graph calls ggml_rope_multi + // with the real sections so it stays correct for multi-axis modes. + const int rope_n_rot = std::min(64, key_len); + const float rope_theta = 1e7f; + + out.clear(); + out.reserve(n_heads); + + // Working buffers. + std::vector embed_buf(n_embd); + std::vector e_in(n_embd); + std::vector h_in(n_embd); + std::vector concat_buf(2 * n_embd); + std::vector x(n_embd); + std::vector x_normed(n_embd); + + // Per-head tensor float caches (dequantized from ggml tensors). + std::vector enorm_data; + std::vector hnorm_data; + std::vector eh_proj_data; + std::vector shared_head_norm_data; + std::vector shared_head_head_data; + +#ifndef MTP_PHASE_A_FALLBACK + std::vector attn_norm_data; + std::vector attn_q_data; // packed: [(Q||gate) x n_embd] + std::vector attn_q_norm_data; + std::vector attn_k_data; + std::vector attn_k_norm_data; + std::vector attn_v_data; + std::vector attn_output_data; + std::vector post_attn_norm_data; + std::vector ffn_gate_data; + std::vector ffn_up_data; + std::vector ffn_down_data; + + // TRMBlock working buffers (sized for worst case; reused across heads). + std::vector q_buf(2 * q_dim); // packed Q+gate from projection + std::vector k_buf(kv_dim); + std::vector v_buf(v_total); + std::vector attn_out_buf(n_head * val_len); + std::vector proj_buf(n_embd); + std::vector ffn_gate_buf(ffn_len); + std::vector ffn_up_buf(ffn_len); + std::vector ffn_in_buf(ffn_len); + std::vector ffn_out(n_embd); +#endif // MTP_PHASE_A_FALLBACK + + std::vector logits_buf; + int32_t cur_token = current_token; + + for (int h = 0; h < n_heads; h++) { + const auto & head = state_->weights.heads[h]; + + // h=0 reads h_{base_pos-1} from the target; h>0 chains use the + // previous head's un-normed output stashed in state_->last_hidden. + const float * h_prev = state_->last_hidden.data(); + if (h == 0) { + bool found_hidden = false; + if (state_->target) { + if (base_pos >= 1) { + // Prefer pre-output-norm (PR #22673 t_h_pre_norm) so + // the head's hnorm doesn't double-normalise. Fall + // back to post-norm if the adapter did not capture + // the pre-norm sequence this verify_batch. + const float * tgt_h = + state_->target->hidden_at_pos_pre_norm(base_pos - 1); + if (!tgt_h) { + tgt_h = state_->target->hidden_at_pos(base_pos - 1); + } + if (tgt_h) { + h_prev = tgt_h; + found_hidden = true; + } + } + if (!found_hidden) { + const float * tgt_h = state_->target->last_hidden(); + if (tgt_h) { + h_prev = tgt_h; + found_hidden = true; + } + } + } + if (!found_hidden && state_->initial_hidden_ptr && + state_->initial_hidden_dim == n_embd) { + h_prev = state_->initial_hidden_ptr; + found_hidden = true; + } + if (!found_hidden) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: no live hidden available at base_pos=%d; " + "using zero vector for h_prev\n", base_pos); + } + } + + // ── Step B: embed current/drafted token ──────────────────────── + const int32_t tok_ids[1] = { cur_token }; + if (!state_->target->embed_tokens(tok_ids, 1, embed_buf.data())) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: embed_tokens failed for head %d\n", h); + out.clear(); + return false; + } + + // ── Step C: Eq 21 — eh_proj([hnorm(h_prev); enorm(embed)]) ─── + + // Load enorm and hnorm + if (!tensor_to_floats(head.enorm, enorm_data) || + (int)enorm_data.size() != n_embd) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: invalid enorm at head %d\n", h); + out.clear(); + return false; + } + if (!tensor_to_floats(head.hnorm, hnorm_data) || + (int)hnorm_data.size() != n_embd) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: invalid hnorm at head %d\n", h); + out.clear(); + return false; + } + + rmsnorm_cpu(embed_buf.data(), enorm_data.data(), e_in.data(), n_embd); + rmsnorm_cpu(h_prev, hnorm_data.data(), h_in.data(), n_embd); + + // Concat order [e_in; h_in] (embed first, hidden second) matches the + // reference llama.cpp PR #22673 graph_mtp: + // `ggml_concat(ctx0, e_norm, h_norm, /*dim=*/0)`. + // The earlier "hidden first" claim in qwen35_mtp_redesign.md was wrong. + std::copy(e_in.begin(), e_in.end(), concat_buf.begin()); + std::copy(h_in.begin(), h_in.end(), concat_buf.begin() + n_embd); + + // Project: x = eh_proj @ concat, shape [n_embd, 2*n_embd] × [2*n_embd] → [n_embd] + if (!tensor_to_floats(head.eh_proj, eh_proj_data) || + (int)eh_proj_data.size() != n_embd * 2 * n_embd) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: invalid eh_proj at head %d " + "(got %zu, expected %d)\n", + h, eh_proj_data.size(), n_embd * 2 * n_embd); + out.clear(); + return false; + } + matvec_cpu(eh_proj_data.data(), concat_buf.data(), x.data(), + n_embd, 2 * n_embd); + +#ifndef MTP_PHASE_A_FALLBACK + // ── Step D: Eq 22 — TRMBlock_k (head-owned attn + FFN) ──────── + // All required tensors must be non-null (validated by loader). + + // Pre-attn norm: cur = RMSNorm(x, attn_norm) + if (!tensor_to_floats(head.attn_norm, attn_norm_data) || + (int)attn_norm_data.size() != n_embd) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: invalid attn_norm at head %d\n", h); + out.clear(); + return false; + } + // We'll use x_normed as the pre-attn cur. + rmsnorm_cpu(x.data(), attn_norm_data.data(), x_normed.data(), n_embd); + + // Q projection (packed Q||gate): [2*q_dim, n_embd] × [n_embd] → [2*q_dim] + // ggml convention: tensor shape [cols, rows], stored row-major (rows × cols). + // attn_q is [n_embd, 2*q_dim] in ggml's ne[] → rows=2*q_dim, cols=n_embd. + if (!tensor_to_floats(head.attn_q, attn_q_data) || + (int)attn_q_data.size() != 2 * q_dim * n_embd) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: invalid attn_q at head %d " + "(got %zu, expected %d)\n", + h, attn_q_data.size(), 2 * q_dim * n_embd); + out.clear(); + return false; + } + // attn_q output is laid out [Q_head_0 | gate_head_0 | Q_head_1 | gate_head_1 | ...] + // per the backbone graph at qwen35_target_graph.cpp:468-486: QG is + // reshaped to [head_dim*2, n_head] and viewed as Q (offset 0) + gate + // (offset head_dim) with stride head_dim*2 between heads. We must + // de-interleave into contiguous Q and gate buffers before QK-norm / + // RoPE / attention (which all assume [n_head, head_dim] layout). + std::vector q_raw(2 * q_dim); + matvec_cpu(attn_q_data.data(), x_normed.data(), q_raw.data(), + 2 * q_dim, n_embd); + std::vector gate_data(q_dim); + q_buf.resize(q_dim); + for (int hd = 0; hd < n_head; hd++) { + const float * src = q_raw.data() + (size_t)hd * 2 * key_len; + std::memcpy(q_buf.data() + (size_t)hd * key_len, src, + sizeof(float) * key_len); + std::memcpy(gate_data.data() + (size_t)hd * key_len, + src + key_len, sizeof(float) * key_len); + } + float * q_ptr = q_buf.data(); + float * gate_ptr = gate_data.data(); + + // K projection: [kv_dim, n_embd] × [n_embd] → [kv_dim] + if (!tensor_to_floats(head.attn_k, attn_k_data) || + (int)attn_k_data.size() != kv_dim * n_embd) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: invalid attn_k at head %d\n", h); + out.clear(); + return false; + } + k_buf.resize(kv_dim); + matvec_cpu(attn_k_data.data(), x_normed.data(), k_buf.data(), + kv_dim, n_embd); + + // V projection: [v_total, n_embd] × [n_embd] → [v_total] + if (!tensor_to_floats(head.attn_v, attn_v_data) || + (int)attn_v_data.size() != v_total * n_embd) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: invalid attn_v at head %d\n", h); + out.clear(); + return false; + } + v_buf.resize(v_total); + matvec_cpu(attn_v_data.data(), x_normed.data(), v_buf.data(), + v_total, n_embd); + + // QK-norm (per-head RMSNorm on Q and K) + if (!tensor_to_floats(head.attn_q_norm, attn_q_norm_data) || + (int)attn_q_norm_data.size() != key_len) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: invalid attn_q_norm at head %d\n", h); + out.clear(); + return false; + } + if (!tensor_to_floats(head.attn_k_norm, attn_k_norm_data) || + (int)attn_k_norm_data.size() != key_len) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: invalid attn_k_norm at head %d\n", h); + out.clear(); + return false; + } + per_head_rmsnorm(q_ptr, attn_q_norm_data.data(), n_head, key_len); + per_head_rmsnorm(k_buf.data(), attn_k_norm_data.data(), n_head_kv, key_len); + + const int draft_pos = base_pos + h; + rope_cpu(q_ptr, n_head, key_len, rope_n_rot, draft_pos, rope_theta); + rope_cpu(k_buf.data(), n_head_kv, key_len, rope_n_rot, draft_pos, rope_theta); + if (draft_pos >= state_->n_ctx || + (int)state_->head_kv.size() <= h) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: draft_pos %d out of head_kv range " + "(n_ctx=%d, head=%d, head_kv_size=%zu)\n", + draft_pos, state_->n_ctx, h, state_->head_kv.size()); + out.clear(); + return false; + } + { + auto & kv = state_->head_kv[h]; + const size_t k_slot_off = (size_t)draft_pos * n_head_kv * key_len; + const size_t v_slot_off = (size_t)draft_pos * n_head_kv * val_len; + std::memcpy(kv.k.data() + k_slot_off, k_buf.data(), + sizeof(float) * (size_t)n_head_kv * key_len); + std::memcpy(kv.v.data() + v_slot_off, v_buf.data(), + sizeof(float) * (size_t)n_head_kv * val_len); + } + + // Range attention over slots [0, draft_pos] of head_kv[h] (causal). + attn_out_buf.resize(n_head * val_len); + range_attention(q_ptr, + state_->head_kv[h].k.data(), + state_->head_kv[h].v.data(), + attn_out_buf.data(), + n_head, n_head_kv, val_len, + /*n_slots=*/draft_pos + 1); + + // Reshape attn output: [n_head * val_len] = [q_dim] (val_len == key_len). + // Apply sigmoid gate from the packed Q. + for (int i = 0; i < q_dim; i++) { + const float g = 1.0f / (1.0f + std::exp(-gate_ptr[i])); + attn_out_buf[i] *= g; + } + + // attn_output projection: [n_embd, q_dim] × [q_dim] → [n_embd] + // attn_output tensor is [head_count*value_length, n_embd] = [q_dim, n_embd] + // in ggml convention: ne[0]=q_dim, ne[1]=n_embd → rows=n_embd, cols=q_dim. + if (!tensor_to_floats(head.attn_output, attn_output_data) || + (int)attn_output_data.size() != n_embd * q_dim) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: invalid attn_output at head %d " + "(got %zu, expected %d)\n", + h, attn_output_data.size(), n_embd * q_dim); + out.clear(); + return false; + } + proj_buf.resize(n_embd); + matvec_cpu(attn_output_data.data(), attn_out_buf.data(), proj_buf.data(), + n_embd, q_dim); + + // Residual: x = x + attn_proj + for (int i = 0; i < n_embd; i++) x[i] += proj_buf[i]; + + // Pre-FFN norm: cur = RMSNorm(x, post_attention_norm) + if (!tensor_to_floats(head.post_attention_norm, post_attn_norm_data) || + (int)post_attn_norm_data.size() != n_embd) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: invalid post_attention_norm at head %d\n", h); + out.clear(); + return false; + } + rmsnorm_cpu(x.data(), post_attn_norm_data.data(), x_normed.data(), n_embd); + + // SwiGLU FFN: ffn_out = ffn_down @ (silu(ffn_gate @ x_n) * (ffn_up @ x_n)) + if (!tensor_to_floats(head.ffn_gate, ffn_gate_data) || + (int)ffn_gate_data.size() != ffn_len * n_embd) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: invalid ffn_gate at head %d\n", h); + out.clear(); + return false; + } + if (!tensor_to_floats(head.ffn_up, ffn_up_data) || + (int)ffn_up_data.size() != ffn_len * n_embd) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: invalid ffn_up at head %d\n", h); + out.clear(); + return false; + } + if (!tensor_to_floats(head.ffn_down, ffn_down_data) || + (int)ffn_down_data.size() != n_embd * ffn_len) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: invalid ffn_down at head %d\n", h); + out.clear(); + return false; + } + + ffn_gate_buf.resize(ffn_len); + ffn_up_buf.resize(ffn_len); + ffn_in_buf.resize(ffn_len); + matvec_cpu(ffn_gate_data.data(), x_normed.data(), ffn_gate_buf.data(), + ffn_len, n_embd); + matvec_cpu(ffn_up_data.data(), x_normed.data(), ffn_up_buf.data(), + ffn_len, n_embd); + for (int i = 0; i < ffn_len; i++) { + ffn_in_buf[i] = silu(ffn_gate_buf[i]) * ffn_up_buf[i]; + } + ffn_out.resize(n_embd); + matvec_cpu(ffn_down_data.data(), ffn_in_buf.data(), ffn_out.data(), + n_embd, ffn_len); + + // Residual: x = x + ffn_out + for (int i = 0; i < n_embd; i++) x[i] += ffn_out[i]; +#endif // MTP_PHASE_A_FALLBACK + + // ── Step E: Eq 23 — shared_head_norm + shared LM head ───────── + + if (head.shared_head_norm && + tensor_to_floats(head.shared_head_norm, shared_head_norm_data) && + (int)shared_head_norm_data.size() == n_embd) { + rmsnorm_cpu(x.data(), shared_head_norm_data.data(), x_normed.data(), n_embd); + } else { + // Fallback: treat as unit weights (norm without scale). + float ss = 0.0f; + for (int i = 0; i < n_embd; i++) ss += x[i] * x[i]; + const float rms_inv = 1.0f / std::sqrt(ss / n_embd + 1e-6f); + for (int i = 0; i < n_embd; i++) x_normed[i] = x[i] * rms_inv; + } + + // LM head projection → draft token + StepOutput step_out; + if (head.shared_head_head && + tensor_to_floats(head.shared_head_head, shared_head_head_data) && + (int)shared_head_head_data.size() == n_vocab * n_embd) { + // Explicit per-head LM head (absent in 27B GGUF): [n_vocab x n_embd] + logits_buf.resize(n_vocab); + matvec_cpu(shared_head_head_data.data(), x_normed.data(), + logits_buf.data(), n_vocab, n_embd); + step_out.draft_token = argmax(logits_buf.data(), n_vocab); + step_out.draft_logit = logits_buf[step_out.draft_token]; + if (state_->draft_topk > 1) { + emit_topk_logprobs(logits_buf.data(), n_vocab, + state_->draft_topk, step_out); + } + } else { + // Standard path: shared LM head via target's project_hidden_to_*. + // For top-K (draft_topk > 1) prefer project_hidden_to_logits so we + // can populate the top-K logprob surface; fall back to the argmax + // path for K=1 or when the target lacks the logits virtual. + if (state_->draft_topk > 1) { + std::vector logits_buf_t; + int vocab = 0; + if (state_->target->project_hidden_to_logits(x_normed.data(), 1, + logits_buf_t, vocab) + && vocab > 0) { + step_out.draft_token = argmax(logits_buf_t.data(), vocab); + step_out.draft_logit = logits_buf_t[step_out.draft_token]; + emit_topk_logprobs(logits_buf_t.data(), vocab, + state_->draft_topk, step_out); + } else { + std::vector tok_out; + if (!state_->target->project_hidden_to_tokens(x_normed.data(), 1, tok_out) + || tok_out.empty()) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: project_hidden_to_tokens failed at head %d\n", h); + out.clear(); + return false; + } + step_out.draft_token = tok_out[0]; + step_out.draft_logit = 0.0f; + static bool warned = false; + if (!warned) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: draft_topk=%d requested but target " + "lacks project_hidden_to_logits; emitting argmax only.\n", + state_->draft_topk); + warned = true; + } + } + } else { + std::vector tok_out; + if (!state_->target->project_hidden_to_tokens(x_normed.data(), 1, tok_out) + || tok_out.empty()) { + std::fprintf(stderr, + "[qwen35_mtp] step_batch: project_hidden_to_tokens failed at head %d\n", h); + out.clear(); + return false; + } + step_out.draft_token = tok_out[0]; + step_out.draft_logit = 0.0f; + } + } + + out.push_back(std::move(step_out)); + + // ── Update chain state for next head ────────────────────────── + // Stash post-residual, pre-shared_head_norm hidden (Eq 22's output). + state_->last_hidden = x; + cur_token = out.back().draft_token; + } + + return true; +} + +const Qwen35MtpWeights & Qwen35MtpModule::weights() const { + return state_->weights; +} + +void Qwen35MtpModule::attach_weights_for_test(const Qwen35MtpWeights & w) { + state_->weights = w; + state_->loaded = true; + state_->last_hidden.assign(w.n_embd, 0.0f); + state_->initial_hidden_ptr = nullptr; + state_->initial_hidden_dim = 0; + // Allocate head_kv so step_batch's range attention can read/write slots. + // Tests typically exercise only a few positions; n_ctx=64 is plenty. + const int n_ctx = 64; + const int n_head_kv = w.n_head_kv; + const int key_len = w.n_key_length; + const int val_len = w.n_value_length; + state_->n_ctx = n_ctx; + state_->head_kv.clear(); + if (n_head_kv > 0 && key_len > 0 && val_len > 0 && w.n_heads > 0) { + state_->head_kv.resize(w.n_heads); + for (int h = 0; h < w.n_heads; h++) { + state_->head_kv[h].k.assign( + (size_t)n_ctx * n_head_kv * key_len, 0.0f); + state_->head_kv[h].v.assign( + (size_t)n_ctx * n_head_kv * val_len, 0.0f); + } + } +} + +void Qwen35MtpModule::set_initial_hidden(const float * h_prev, int dim) { + // Stash caller's pointer + dim. The pointer must remain valid for the + // duration of the next step_batch() call. + // The Shape B TRMBlock forward reads it. + state_->initial_hidden_ptr = h_prev; + state_->initial_hidden_dim = dim; +} + +bool Qwen35MtpModule::warm_head_kv(const int32_t * prompt_tokens, + int n_prompt, + int32_t prefill_next, + const float * hiddens) { + if (!state_->loaded || !state_->attached || !state_->target) { + std::fprintf(stderr, + "[qwen35_mtp] warm_head_kv: module not loaded/attached\n"); + return false; + } + if (n_prompt <= 0 || !prompt_tokens || !hiddens) return true; + + const int n_embd = state_->weights.n_embd; + const int n_heads = state_->weights.n_heads; + const int n_head_kv = state_->weights.n_head_kv; + const int key_len = state_->weights.n_key_length; + const int val_len = state_->weights.n_value_length; + const int kv_dim = n_head_kv * key_len; + const int v_total = n_head_kv * val_len; + const int rope_n_rot = std::min(64, key_len); + const float rope_theta = 1e7f; + + if (n_prompt >= state_->n_ctx) { + std::fprintf(stderr, + "[qwen35_mtp] warm_head_kv: n_prompt=%d exceeds head_kv capacity n_ctx=%d\n", + n_prompt, state_->n_ctx); + return false; + } + if ((int)state_->head_kv.size() < n_heads) { + std::fprintf(stderr, + "[qwen35_mtp] warm_head_kv: head_kv not allocated (size=%zu, expected=%d)\n", + state_->head_kv.size(), n_heads); + return false; + } + + // GPU path: when the backbone backend is available, batch-process all + // n_prompt positions in a single cgraph. Replaces the host-side per-slot + // dequant+matvec+upload loop (~2 s on Qwen3.6-27B with 69-token prompt) + // with one backend pass (~tens of ms). + if (state_->kv_ctx && state_->target->backend()) { + const int h = 0; // GGUF has n_heads=1; multi-head warmup would loop + const auto & head = state_->weights.heads[h]; + int rope_sections[4] = { 11, 11, 10, 0 }; + const int slot_start = 1; // slot 0 unused (no h_{-1} input) + ggml_backend_t backend = state_->target->backend(); + + // Pre-embed all input tokens on host. input_tok_seq[i] is the token at + // sequence position (i+1): prompt_tokens[i+1] for i+1 < n_prompt, + // prefill_next at slot n_prompt. + std::vector input_tok_seq(n_prompt); + for (int i = 0; i < n_prompt; i++) { + const int p = i + 1; + input_tok_seq[i] = (p < n_prompt) ? prompt_tokens[p] : prefill_next; + } + std::vector embed_seq((size_t)n_prompt * n_embd); + if (!state_->target->embed_tokens(input_tok_seq.data(), n_prompt, + embed_seq.data())) { + std::fprintf(stderr, "[qwen35_mtp gpu-warm] embed_tokens failed\n"); + return false; + } + + // MROPE positions, interleaved (4 axes per token) matching backbone. + std::vector pos_seq(4 * n_prompt); + for (int i = 0; i < n_prompt; i++) { + const int p = i + 1; + pos_seq[4 * i + 0] = p; + pos_seq[4 * i + 1] = p; + pos_seq[4 * i + 2] = p; + pos_seq[4 * i + 3] = 0; + } + + if (!build_qwen35_mtp_warm_graph(state_->warm_sg, head, + state_->head_kv[h].k_cache, + state_->head_kv[h].v_cache, + backend, + n_embd, n_head_kv, key_len, val_len, + rope_n_rot, rope_sections, + rope_theta, 1e-6f, + slot_start, n_prompt)) { + return false; + } + + ggml_backend_tensor_set(state_->warm_sg.inp_embed_seq, embed_seq.data(), + 0, sizeof(float) * embed_seq.size()); + ggml_backend_tensor_set(state_->warm_sg.inp_h_seq, hiddens, + 0, sizeof(float) * (size_t)n_prompt * n_embd); + ggml_backend_tensor_set(state_->warm_sg.inp_pos, pos_seq.data(), + 0, sizeof(int32_t) * pos_seq.size()); + + auto st = ggml_backend_graph_compute(backend, state_->warm_sg.gf); + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, + "[qwen35_mtp gpu-warm] graph_compute status=%d\n", (int)st); + return false; + } + return true; + } + + // We only warm head 0's K/V (the only head on this 27B GGUF). Multi-head + // GGUFs would warm each head; their h_prev would be the previous head's + // un-normed output, which we don't have during prefill. The handoff and + // redesign doc both pin γ_max=1 for this GGUF. + const int h = 0; + const auto & head = state_->weights.heads[h]; + + // Working buffers (sized once, reused across positions). + std::vector embed_buf(n_embd); + std::vector e_in(n_embd); + std::vector h_in(n_embd); + std::vector concat_buf(2 * n_embd); + std::vector x(n_embd); + std::vector x_normed(n_embd); + std::vector k_buf(kv_dim); + std::vector v_buf(v_total); + + // Dequantize head's per-position-invariant tensors once. + std::vector enorm_data, hnorm_data, eh_proj_data; + std::vector attn_norm_data, attn_k_data, attn_k_norm_data, attn_v_data; + if (!tensor_to_floats(head.enorm, enorm_data) || + (int)enorm_data.size() != n_embd || + !tensor_to_floats(head.hnorm, hnorm_data) || + (int)hnorm_data.size() != n_embd || + !tensor_to_floats(head.eh_proj, eh_proj_data) || + (int)eh_proj_data.size() != n_embd * 2 * n_embd || + !tensor_to_floats(head.attn_norm, attn_norm_data) || + (int)attn_norm_data.size() != n_embd || + !tensor_to_floats(head.attn_k, attn_k_data) || + (int)attn_k_data.size() != kv_dim * n_embd || + !tensor_to_floats(head.attn_v, attn_v_data) || + (int)attn_v_data.size() != v_total * n_embd || + !tensor_to_floats(head.attn_k_norm, attn_k_norm_data) || + (int)attn_k_norm_data.size() != key_len) { + std::fprintf(stderr, + "[qwen35_mtp] warm_head_kv: failed to dequant head tensors\n"); + return false; + } + + auto & kv = state_->head_kv[h]; + // Slot p of head_kv represents sequence position p (matching backbone KV + // slot p, matching RoPE position p, matching step_batch's draft_pos = base_pos+h). + // The head's K/V at slot p use inputs (h_{p-1}, t_p): backbone post-norm + // hidden at the PREVIOUS position and the input token AT position p. + // Slot 0 has no h_{-1}, so it stays zero (Q at later slots will see slot 0 + // as a near-zero contribution; softmax shifts mass to populated slots). + // For p in [1, n_prompt-1]: input_tok = prompt_tokens[p]. + // For p = n_prompt: input_tok = prefill_next, h_{p-1} = h_{n_prompt-1} + // (the last prefill hidden). + const int last_slot = n_prompt; // inclusive + if (last_slot >= state_->n_ctx) { + std::fprintf(stderr, + "[qwen35_mtp] warm_head_kv: last_slot=%d exceeds n_ctx=%d\n", + last_slot, state_->n_ctx); + return false; + } + for (int p = 1; p <= last_slot; p++) { + const int32_t input_tok = + (p < n_prompt) ? prompt_tokens[p] : prefill_next; + + if (!state_->target->embed_tokens(&input_tok, 1, embed_buf.data())) { + std::fprintf(stderr, + "[qwen35_mtp] warm_head_kv: embed_tokens failed at p=%d\n", p); + return false; + } + rmsnorm_cpu(embed_buf.data(), enorm_data.data(), e_in.data(), n_embd); + + // h_{p-1}: backbone post-norm hidden at the PREVIOUS sequence position. + const float * h_prev_p = hiddens + (size_t)(p - 1) * n_embd; + rmsnorm_cpu(h_prev_p, hnorm_data.data(), h_in.data(), n_embd); + + // Concat order [e_in; h_in] matches llama.cpp PR #22673 graph_mtp. + std::copy(e_in.begin(), e_in.end(), concat_buf.begin()); + std::copy(h_in.begin(), h_in.end(), concat_buf.begin() + n_embd); + matvec_cpu(eh_proj_data.data(), concat_buf.data(), x.data(), + n_embd, 2 * n_embd); + + rmsnorm_cpu(x.data(), attn_norm_data.data(), x_normed.data(), n_embd); + + matvec_cpu(attn_k_data.data(), x_normed.data(), k_buf.data(), + kv_dim, n_embd); + matvec_cpu(attn_v_data.data(), x_normed.data(), v_buf.data(), + v_total, n_embd); + + per_head_rmsnorm(k_buf.data(), attn_k_norm_data.data(), n_head_kv, key_len); + rope_cpu(k_buf.data(), n_head_kv, key_len, rope_n_rot, p, rope_theta); + + const size_t k_slot_off = (size_t)p * n_head_kv * key_len; + const size_t v_slot_off = (size_t)p * n_head_kv * val_len; + std::memcpy(kv.k.data() + k_slot_off, k_buf.data(), + sizeof(float) * (size_t)n_head_kv * key_len); + std::memcpy(kv.v.data() + v_slot_off, v_buf.data(), + sizeof(float) * (size_t)n_head_kv * val_len); + } + // Prefill is done. From this point on, every verify_batch is a decode + // step whose ONLY hidden-sequence consumer is hidden_at_pos(base_pos-1) + // (the chain's iter-0 h_prev seed). Tell the target to download only + // that single row from all_norm_hidden / all_h_pre_norm instead of the + // full [n_tokens, n_embd] tensor — collapses the 2x per-verify ~80 KB + // device->host sync to a 2x ~20 KB sync (hidden_dim=5120, D+1=4 tokens + // baseline) and erases the WSL2 cudaStreamSynchronize scheduler tax + // that dominated decode-side verify_batch in the verify_prof traces. + if (auto * t = dynamic_cast(state_->target)) { + t->set_hidden_capture_mode( + Qwen35DFlashTarget::VerifyCaptureMode::LAST_ROW_ONLY); + } + return true; +} + +const float * Qwen35MtpModule::test_initial_hidden_ptr() const { + return state_->initial_hidden_ptr; +} +int Qwen35MtpModule::test_initial_hidden_dim() const { + return state_->initial_hidden_dim; +} + +// Bug #5 fix: shape-only graph cached per (head_idx, fa_window, fused, topk). +// Per-call slot routing is uploaded via push_kv_slot_inputs_() below. + +// Push the runtime KV routing inputs (write slot, read idxs, mask) for the +// current draft_pos / kv_len. fa_max is baked into the graph at build time. +static void push_kv_slot_inputs_(Qwen35MtpStepGraph * sg, + int draft_pos, int kv_len, + int n_head_kv) { + const int fa_max = sg->fa_max; + const int fa_win = sg->fa_window; + const int fa_kv_lo = (fa_win > 0 && kv_len > fa_win) ? (kv_len - fa_win) : 0; + const int fa_kv_n = std::min(kv_len - fa_kv_lo, fa_max); + + const int64_t widx = (int64_t)draft_pos; + ggml_backend_tensor_set(sg->inp_kv_idx_write, &widx, 0, sizeof(int64_t)); + + std::vector ridx((size_t)fa_max * n_head_kv); + for (int h = 0; h < n_head_kv; ++h) { + int32_t * row = ridx.data() + (size_t)h * fa_max; + for (int i = 0; i < fa_max; ++i) { + row[i] = (i < fa_kv_n) ? (fa_kv_lo + i) : 0; // unused rows masked to -INF + } + } + ggml_backend_tensor_set(sg->inp_kv_idxs_read, ridx.data(), 0, + sizeof(int32_t) * ridx.size()); + + std::vector mask((size_t)fa_max, 0); + const uint16_t neg_inf_f16 = 0xFC00; + for (int i = fa_kv_n; i < fa_max; ++i) mask[i] = neg_inf_f16; + ggml_backend_tensor_set(sg->inp_kv_mask, mask.data(), 0, + sizeof(uint16_t) * mask.size()); +} + +Qwen35MtpStepGraph * Qwen35MtpModule::get_or_build_step_graph_(int head_idx) { + if (head_idx < 0 || head_idx >= (int)state_->weights.heads.size()) { + return nullptr; + } + if (head_idx >= (int)state_->head_kv.size()) return nullptr; + + const int n_embd = state_->weights.n_embd; + const int n_head = state_->weights.n_head_count; + const int n_head_kv = state_->weights.n_head_kv; + const int key_len = state_->weights.n_key_length; + const int val_len = state_->weights.n_value_length; + const int ffn_len = state_->weights.n_ffn_length; + const int n_rot = std::min(64, key_len); + const float rope_freq_base = 1e7f; + const float rms_eps = 1e-6f; + int rope_sections[4] = { 11, 11, 10, 0 }; + + const int fa_win = state_->target ? state_->target->fa_window() : 0; + ggml_tensor * lm_head = state_->target ? state_->target->lm_head_weight() : nullptr; + const bool fused = (lm_head != nullptr); + const int topk_k = (fused && state_->draft_topk > 1) ? state_->draft_topk : 0; + + // Find matching cached entry; else pick first empty / oldest slot. + int hit = -1, free_slot = -1; + for (int i = 0; i < (int)state_->step_sg_cache.size(); ++i) { + const auto & k = state_->step_sg_cache[i].first; + if (state_->step_sg_cache[i].second && + k.head_idx == head_idx && k.fa_window == fa_win && + k.fused_lm_head == fused && k.topk_k == topk_k) { + hit = i; break; + } + if (!state_->step_sg_cache[i].second && free_slot < 0) free_slot = i; + } + if (hit >= 0) return state_->step_sg_cache[hit].second.get(); + if (free_slot < 0) free_slot = 0; // evict slot 0 (FIFO is fine; cap=4) + + auto & slot = state_->step_sg_cache[free_slot]; + if (slot.second) qwen35_mtp_step_graph_free(*slot.second); + else slot.second.reset(new Qwen35MtpStepGraph()); + + const auto & head = state_->weights.heads[head_idx]; + if (!build_qwen35_mtp_step_graph(*slot.second, head, + state_->head_kv[head_idx].k_cache, + state_->head_kv[head_idx].v_cache, + state_->target->backend(), + n_embd, n_head, n_head_kv, + key_len, val_len, ffn_len, + n_rot, rope_sections, + rope_freq_base, rms_eps, + state_->n_ctx, + fa_win, lm_head, topk_k)) { + std::fprintf(stderr, + "[qwen35_mtp] get_or_build_step_graph_: build failed head=%d\n", + head_idx); + slot.second.reset(); + slot.first = State::StepGraphKey{}; + return nullptr; + } + slot.first = State::StepGraphKey{head_idx, fa_win, fused, topk_k}; + return slot.second.get(); +} + +// Per-call cgraph on the backbone backend; cached per (head_idx, draft_pos) +// in state_->step_sg_cache. When the bound target exposes its +// lm_head_weight() the graph also fuses the LM-head matmul + argmax so we +// skip the hidden -> host -> separate-cgraph round trip per call. +bool Qwen35MtpModule::step_batch_gpu_(int32_t current_token, + int base_pos, + std::vector & out) { + const int n_embd = state_->weights.n_embd; + const int n_heads = state_->weights.n_heads; + + out.clear(); + out.reserve(n_heads); + + ggml_backend_t backend = state_->target->backend(); + int32_t cur_token = current_token; + std::vector embed_buf(n_embd); + + for (int h = 0; h < n_heads; h++) { + const int draft_pos = base_pos + h; + if (draft_pos >= state_->n_ctx) { + std::fprintf(stderr, + "[qwen35_mtp gpu] draft_pos=%d exceeds n_ctx=%d\n", + draft_pos, state_->n_ctx); + out.clear(); + return false; + } + const int kv_len = draft_pos + 1; + + // ── Select h_prev (same priority as CPU path) ─────────────────── + // Pre-output-norm preferred (PR #22673 t_h_pre_norm) — head's + // hnorm normalises h_prev internally and post-norm seed double- + // normalises. Fall back to post-norm if the adapter did not + // capture the pre-norm sequence this verify_batch. + const float * h_prev = nullptr; + if (h == 0) { + if (base_pos >= 1) { + h_prev = state_->target->hidden_at_pos_pre_norm(base_pos - 1); + if (!h_prev) { + h_prev = state_->target->hidden_at_pos(base_pos - 1); + } + } + if (!h_prev) h_prev = state_->target->last_hidden(); + if (!h_prev && state_->initial_hidden_ptr && + state_->initial_hidden_dim == n_embd) { + h_prev = state_->initial_hidden_ptr; + } + if (!h_prev) { + std::fprintf(stderr, + "[qwen35_mtp gpu] no hidden available at base_pos=%d\n", + base_pos); + out.clear(); + return false; + } + } else { + // h>0 chain: use the head's own previous output (kept on host as + // last_hidden). Only matters when n_heads > 1; the 27B GGUF has + // n_heads=1 so this branch is rarely exercised. + h_prev = state_->last_hidden.data(); + } + // Embed cur_token via target (already on host). + if (!state_->target->embed_tokens(&cur_token, 1, embed_buf.data())) { + std::fprintf(stderr, "[qwen35_mtp gpu] embed_tokens failed h=%d\n", h); + out.clear(); + return false; + } + + // Get-or-build the shape-only step graph. + Qwen35MtpStepGraph * sg = get_or_build_step_graph_(h); + if (!sg) { out.clear(); return false; } + + // Upload inputs. Pass h_prev directly (no scratch memcpy — task E). + ggml_backend_tensor_set(sg->inp_embed, + embed_buf.data(), 0, sizeof(float) * n_embd); + ggml_backend_tensor_set(sg->inp_h_prev, + h_prev, 0, sizeof(float) * n_embd); + // MROPE positions: text-only mode (axes 0..2 = position, axis 3 = 0). + const int32_t pos[4] = { draft_pos, draft_pos, draft_pos, 0 }; + ggml_backend_tensor_set(sg->inp_pos, pos, 0, sizeof(pos)); + push_kv_slot_inputs_(sg, draft_pos, kv_len, state_->weights.n_head_kv); + + auto st = ggml_backend_graph_compute(backend, sg->gf); + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "[qwen35_mtp gpu] graph_compute status=%d\n", (int)st); + out.clear(); + return false; + } + + StepOutput so; + if (sg->fused_lm_head && sg->out_argmax_token) { + // Fused path: read the argmax (and optional full logits for + // top-K) directly from the cached graph's outputs — no separate + // projection cgraph, no hidden -> host round trip. + int32_t tok = 0; + ggml_backend_tensor_get(sg->out_argmax_token, &tok, 0, sizeof(int32_t)); + so.draft_token = tok; + so.draft_logit = 0.0f; + if (sg->out_logits && state_->draft_topk > 1) { + const int vocab = (int)sg->out_logits->ne[0]; + std::vector logits_buf((size_t)vocab); + ggml_backend_tensor_get(sg->out_logits, logits_buf.data(), + 0, sizeof(float) * vocab); + so.draft_logit = logits_buf[tok]; + emit_topk_logprobs(logits_buf.data(), vocab, + state_->draft_topk, so); + } + // Chain state: when n_heads > 1 the next head's h_prev needs the + // hidden. Skip the download otherwise to keep the fused path + // zero-host-roundtrip. Read PRE-shared_head_norm hidden — feeding + // post-norm here causes the next iter's `hnorm` to double- + // normalise (see qwen35_mtp_graph.cpp pre-norm output comment; + // mirrors llama.cpp PR #22673 `t_h_pre_norm` design). + if (n_heads > 1) { + std::vector h_pre(n_embd); + ggml_backend_tensor_get(sg->out_h_pre_norm, h_pre.data(), + 0, sizeof(float) * n_embd); + state_->last_hidden = std::move(h_pre); + } + } else { + // Unfused fallback: read hidden, project on host via target. + std::vector x_normed(n_embd); + ggml_backend_tensor_get(sg->out_x_normed, + x_normed.data(), 0, sizeof(float) * n_embd); + if (state_->draft_topk > 1) { + std::vector logits_buf; + int vocab = 0; + if (state_->target->project_hidden_to_logits(x_normed.data(), 1, + logits_buf, vocab) && + vocab > 0) { + so.draft_token = argmax(logits_buf.data(), vocab); + so.draft_logit = logits_buf[so.draft_token]; + emit_topk_logprobs(logits_buf.data(), vocab, + state_->draft_topk, so); + } else { + std::vector tok_out; + if (!state_->target->project_hidden_to_tokens(x_normed.data(), 1, + tok_out) || + tok_out.empty()) { + std::fprintf(stderr, + "[qwen35_mtp gpu] project_hidden_to_tokens failed h=%d\n", h); + out.clear(); + return false; + } + so.draft_token = tok_out[0]; + so.draft_logit = 0.0f; + static bool warned = false; + if (!warned) { + std::fprintf(stderr, + "[qwen35_mtp gpu] draft_topk=%d requested but target " + "lacks project_hidden_to_logits; emitting argmax only.\n", + state_->draft_topk); + warned = true; + } + } + } else { + std::vector tok_out; + if (!state_->target->project_hidden_to_tokens(x_normed.data(), 1, + tok_out) || + tok_out.empty()) { + std::fprintf(stderr, + "[qwen35_mtp gpu] project_hidden_to_tokens failed h=%d\n", h); + out.clear(); + return false; + } + so.draft_token = tok_out[0]; + so.draft_logit = 0.0f; + } + // last_hidden must be PRE-shared_head_norm (chain h_prev contract; + // see qwen35_mtp_graph.cpp pre-norm output comment and + // llama.cpp PR #22673 `t_h_pre_norm`). `x_normed` is consumed + // above for the LM-head projection only. + if (n_heads > 1) { + std::vector h_pre(n_embd); + ggml_backend_tensor_get(sg->out_h_pre_norm, h_pre.data(), + 0, sizeof(float) * n_embd); + state_->last_hidden = std::move(h_pre); + } + } + out.push_back(so); + cur_token = so.draft_token; + } + return true; +} + +// Autoregressive chain draft. Reuses head 0 (the only NextN head on +// the Qwen3.6-27B GGUF) `chain_depth` times, feeding the head's own +// post-shared_head_norm hidden back as h_prev for the next iteration. +// Per-iter step graphs are cached in state_->step_sg_cache. +// +// CPU stub path (no backend / no kv_ctx): degrade gracefully to the default +// `step_batch`+clamp. Unit tests exercising the CPU forward at depth=1 +// remain byte-identical to the old behaviour. +bool Qwen35MtpModule::step_chain(int32_t current_token, + int base_pos, + int chain_depth, + std::vector & out) { + out.clear(); + if (!state_->loaded || !state_->attached) return false; + if (chain_depth <= 0) chain_depth = 1; + + // Single-iter fast path: byte-identical to the established step_batch + // contract. Avoids the additional state_->last_hidden plumbing in the + // depth>1 loop below. + if (chain_depth == 1) { + std::vector tmp; + if (!step_batch(current_token, base_pos, tmp)) return false; + if (!tmp.empty()) out.push_back(std::move(tmp.front())); + return true; + } + + // CPU stub fallback for depth>1 — not exercised by production today + // (the only depth>1 caller goes through the GPU path); degrade to one + // step_batch call so tests that swap out the backend still link. + const bool gpu_ready = (state_->kv_ctx && state_->target && + state_->target->backend()); + if (!gpu_ready) { + std::vector tmp; + if (!step_batch(current_token, base_pos, tmp)) return false; + if (!tmp.empty()) out.push_back(std::move(tmp.front())); + return true; + } + + // ── GPU multi-iter chain ───────────────────────────────────────── + // Per-iter step graphs are pulled from state_->step_sg_cache via + // get_or_build_step_graph_(). First-pass at each draft_pos is a build; + // subsequent calls are a pure tensor_set + compute. When the bound + // target exposes lm_head_weight() the fused argmax output is read from + // the graph directly so we skip the projection-cgraph round trip. + out.reserve(chain_depth); + + ggml_backend_t backend = state_->target->backend(); + + // The Qwen3.6-27B GGUF has n_heads=1; chain depth replays that single + // head against successive draft positions. If a future GGUF lands with + // multiple physical NextN heads, this implementation would still chain + // on head 0 only — multi-head + chain interaction is a Phase >A concern. + const int h = 0; + if ((int)state_->weights.heads.size() <= h || + (int)state_->head_kv.size() <= h) { + std::fprintf(stderr, + "[qwen35_mtp gpu chain] head 0 missing (heads=%zu head_kv=%zu)\n", + state_->weights.heads.size(), state_->head_kv.size()); + return false; + } + + const int n_embd = state_->weights.n_embd; + int32_t cur_token = current_token; + std::vector embed_buf(n_embd); + +#ifdef DFLASH_MTP_PROFILE + // Profiling accumulators (DFLASH_MTP_PROFILE=1). All in ms. + const bool prof_on = mtp_profile_enabled(); + double prof_sum_embed = 0.0, prof_sum_set = 0.0, prof_sum_compute = 0.0; + double prof_sum_get_x = 0.0, prof_sum_get_h = 0.0, prof_sum_get_argmax = 0.0; + double prof_sum_total = 0.0; + int prof_iters = 0; +#endif // DFLASH_MTP_PROFILE + + for (int it = 0; it < chain_depth; it++) { + const int draft_pos = base_pos + it; + if (draft_pos >= state_->n_ctx) { + std::fprintf(stderr, + "[qwen35_mtp gpu chain] draft_pos=%d exceeds n_ctx=%d (iter=%d)\n", + draft_pos, state_->n_ctx, it); + return false; + } + const int kv_len = draft_pos + 1; + + // h_prev resolution mirrors step_batch_gpu_: + // iter 0: pull from target (hidden_at_pos_pre_norm > hidden_at_pos + // > last_hidden > initial). Pre-norm preferred to mirror + // llama.cpp PR #22673 `t_h_pre_norm` — the head's hnorm + // normalises h_prev internally so the post-output-norm + // tensor double-normalises and crushes D>=2 accept. + // iter h>0: use state_->last_hidden written by the previous iter + // from out_h_pre_norm (the pre-shared_head_norm hidden; + // commit 9850ec9). + const float * h_prev = nullptr; + if (it == 0) { + if (base_pos >= 1) { + h_prev = state_->target->hidden_at_pos_pre_norm(base_pos - 1); + if (!h_prev) { + // Fallback: adapter did not capture the pre-norm + // sequence this verify_batch. Post-norm degrades + // D>=2 accept (commit 9850ec9 fixed the intra-iter + // re-feed; this is the OUTER seed and silently + // returns to the pre-9850ec9 regime when it fires). + // Warn once per process so production knows. + static bool warned_post_norm = false; + if (!warned_post_norm) { + std::fprintf(stderr, + "[qwen35_mtp gpu chain] WARN: hidden_at_pos_pre_norm " + "returned null at base_pos=%d, falling back to " + "post-norm hidden. This silently undoes the " + "PR #22673 t_h_pre_norm fix at iter 0 and crushes " + "D>=2 accept. Caller should call " + "enable_hidden_seq_capture(true) on the target " + "BEFORE the prefill verify_batch. (warning fires once)\n", + base_pos); + warned_post_norm = true; + } + h_prev = state_->target->hidden_at_pos(base_pos - 1); + } + } + if (!h_prev) h_prev = state_->target->last_hidden(); + if (!h_prev && state_->initial_hidden_ptr && + state_->initial_hidden_dim == n_embd) { + h_prev = state_->initial_hidden_ptr; + } + if (!h_prev) { + std::fprintf(stderr, + "[qwen35_mtp gpu chain] no hidden available at base_pos=%d\n", + base_pos); + return false; + } + } else { + if ((int)state_->last_hidden.size() != n_embd) { + std::fprintf(stderr, + "[qwen35_mtp gpu chain] last_hidden missing for iter %d\n", it); + return false; + } + h_prev = state_->last_hidden.data(); + } + +#ifdef DFLASH_MTP_PROFILE + prof_clock::time_point t_iter0 = prof_on ? prof_clock::now() : prof_clock::time_point{}; + prof_clock::time_point t0 = prof_on ? prof_clock::now() : prof_clock::time_point{}; +#endif // DFLASH_MTP_PROFILE + if (!state_->target->embed_tokens(&cur_token, 1, embed_buf.data())) { + std::fprintf(stderr, + "[qwen35_mtp gpu chain] embed_tokens failed iter=%d\n", it); + return false; + } +#ifdef DFLASH_MTP_PROFILE + const double t_embed = prof_on ? prof_ms_since(t0) : 0.0; +#endif // DFLASH_MTP_PROFILE + + Qwen35MtpStepGraph * sg = get_or_build_step_graph_(h); + if (!sg) return false; + +#ifdef DFLASH_MTP_PROFILE + t0 = prof_on ? prof_clock::now() : prof_clock::time_point{}; +#endif + ggml_backend_tensor_set(sg->inp_embed, + embed_buf.data(), 0, sizeof(float) * n_embd); +#ifdef DFLASH_MTP_PROFILE + const double t_set_embed = prof_on ? prof_ms_since(t0) : 0.0; + t0 = prof_on ? prof_clock::now() : prof_clock::time_point{}; +#endif + // Pass h_prev directly (no scratch memcpy — task E). + ggml_backend_tensor_set(sg->inp_h_prev, + h_prev, 0, sizeof(float) * n_embd); +#ifdef DFLASH_MTP_PROFILE + const double t_set_hprev = prof_on ? prof_ms_since(t0) : 0.0; + t0 = prof_on ? prof_clock::now() : prof_clock::time_point{}; +#endif + const int32_t pos[4] = { draft_pos, draft_pos, draft_pos, 0 }; + ggml_backend_tensor_set(sg->inp_pos, pos, 0, sizeof(pos)); + push_kv_slot_inputs_(sg, draft_pos, kv_len, state_->weights.n_head_kv); +#ifdef DFLASH_MTP_PROFILE + const double t_set_pos = prof_on ? prof_ms_since(t0) : 0.0; + t0 = prof_on ? prof_clock::now() : prof_clock::time_point{}; +#endif + auto st = ggml_backend_graph_compute(backend, sg->gf); +#ifdef DFLASH_MTP_PROFILE + const double t_compute = prof_on ? prof_ms_since(t0) : 0.0; +#endif + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, + "[qwen35_mtp gpu chain] graph_compute status=%d iter=%d\n", + (int)st, it); + return false; + } + + StepOutput so; + // h_prev for the next iter must be PRE-shared_head_norm (not post-norm); + // feeding post-norm double-normalises against head's hnorm (see PR #22673). + // On the fused-LM-head greedy path neither x_normed nor h_pre is needed — + // the argmax is already on-device; skip both downloads. + std::vector x_normed; + const bool need_x_normed = + !(sg->fused_lm_head && sg->out_argmax_token); +#ifdef DFLASH_MTP_PROFILE + double t_get_x = 0.0; +#endif + if (need_x_normed) { + x_normed.resize(n_embd); +#ifdef DFLASH_MTP_PROFILE + t0 = prof_on ? prof_clock::now() : prof_clock::time_point{}; +#endif + ggml_backend_tensor_get(sg->out_x_normed, + x_normed.data(), 0, sizeof(float) * n_embd); +#ifdef DFLASH_MTP_PROFILE + t_get_x = prof_on ? prof_ms_since(t0) : 0.0; +#endif + } + // Only download h_pre when ANOTHER iteration follows that will + // consume it as h_prev. On the last chain step the value is + // discarded — skip the device->host transfer entirely. + const bool need_h_pre = (it + 1 < chain_depth); + std::vector h_pre; +#ifdef DFLASH_MTP_PROFILE + double t_get_h = 0.0; +#endif + if (need_h_pre) { + h_pre.resize(n_embd); +#ifdef DFLASH_MTP_PROFILE + t0 = prof_on ? prof_clock::now() : prof_clock::time_point{}; +#endif + ggml_backend_tensor_get(sg->out_h_pre_norm, + h_pre.data(), 0, sizeof(float) * n_embd); +#ifdef DFLASH_MTP_PROFILE + t_get_h = prof_on ? prof_ms_since(t0) : 0.0; +#endif + } + +#ifdef DFLASH_MTP_PROFILE + double t_get_argmax = 0.0; +#endif + if (sg->fused_lm_head && sg->out_argmax_token) { + int32_t tok = 0; +#ifdef DFLASH_MTP_PROFILE + t0 = prof_on ? prof_clock::now() : prof_clock::time_point{}; +#endif + ggml_backend_tensor_get(sg->out_argmax_token, &tok, 0, sizeof(int32_t)); +#ifdef DFLASH_MTP_PROFILE + t_get_argmax = prof_on ? prof_ms_since(t0) : 0.0; +#endif + so.draft_token = tok; + so.draft_logit = 0.0f; + if (sg->out_logits && state_->draft_topk > 1) { + const int vocab = (int)sg->out_logits->ne[0]; + std::vector logits_buf((size_t)vocab); + ggml_backend_tensor_get(sg->out_logits, logits_buf.data(), + 0, sizeof(float) * vocab); + so.draft_logit = logits_buf[tok]; + emit_topk_logprobs(logits_buf.data(), vocab, + state_->draft_topk, so); + } + } else if (state_->draft_topk > 1) { + std::vector logits_buf; + int vocab = 0; + if (state_->target->project_hidden_to_logits(x_normed.data(), 1, + logits_buf, vocab) && + vocab > 0) { + so.draft_token = argmax(logits_buf.data(), vocab); + so.draft_logit = logits_buf[so.draft_token]; + emit_topk_logprobs(logits_buf.data(), vocab, + state_->draft_topk, so); + } else { + std::vector tok_out; + if (!state_->target->project_hidden_to_tokens(x_normed.data(), 1, + tok_out) || + tok_out.empty()) { + std::fprintf(stderr, + "[qwen35_mtp gpu chain] project_hidden_to_tokens failed iter=%d\n", it); + return false; + } + so.draft_token = tok_out[0]; + so.draft_logit = 0.0f; + } + } else { + std::vector tok_out; + if (!state_->target->project_hidden_to_tokens(x_normed.data(), 1, + tok_out) || + tok_out.empty()) { + std::fprintf(stderr, + "[qwen35_mtp gpu chain] project_hidden_to_tokens failed iter=%d\n", it); + return false; + } + so.draft_token = tok_out[0]; + so.draft_logit = 0.0f; + } + out.push_back(so); + + // Stash PRE-shared_head_norm hidden for next iter's h_prev and + // advance cur_token to the freshly drafted token. Post-norm here + // (the previous behaviour) compounded a distribution drift per + // depth — see the pre-norm-hidden comment above. Skipped on the + // final iter (h_pre stays empty when need_h_pre=false). + if (need_h_pre) { + state_->last_hidden = std::move(h_pre); + } + cur_token = so.draft_token; + +#ifdef DFLASH_MTP_PROFILE + if (prof_on) { + const double t_total = prof_ms_since(t_iter0); + std::fprintf(stderr, + "[mtp_prof iter=%d] embed=%.3f set=%.3f.%.3f.%.3f " + "compute=%.3f get=%.3f.%.3f argmax=%.3f total=%.3f (ms)\n", + it, t_embed, t_set_embed, t_set_hprev, t_set_pos, + t_compute, t_get_x, t_get_h, t_get_argmax, t_total); + prof_sum_embed += t_embed; + prof_sum_set += t_set_embed + t_set_hprev + t_set_pos; + prof_sum_compute += t_compute; + prof_sum_get_x += t_get_x; + prof_sum_get_h += t_get_h; + prof_sum_get_argmax += t_get_argmax; + prof_sum_total += t_total; + ++prof_iters; + } +#endif // DFLASH_MTP_PROFILE + } + +#ifdef DFLASH_MTP_PROFILE + if (prof_on && prof_iters > 0) { + const double sum_get = prof_sum_get_x + prof_sum_get_h + prof_sum_get_argmax; + const double denom = prof_sum_total > 0.0 ? prof_sum_total : 1.0; + const double pct_embed = 100.0 * prof_sum_embed / denom; + const double pct_set = 100.0 * prof_sum_set / denom; + const double pct_compute = 100.0 * prof_sum_compute / denom; + const double pct_get = 100.0 * sum_get / denom; + std::fprintf(stderr, + "[mtp_prof chain depth=%d iters=%d avg_iter=%.3f ms " + "breakdown_pct: embed=%.1f%% set=%.1f%% compute=%.1f%% get=%.1f%% " + "(get_x=%.3f get_h=%.3f get_argmax=%.3f ms total)]\n", + chain_depth, prof_iters, prof_sum_total / prof_iters, + pct_embed, pct_set, pct_compute, pct_get, + prof_sum_get_x, prof_sum_get_h, prof_sum_get_argmax); + } +#endif // DFLASH_MTP_PROFILE + return true; +} + +} // namespace dflash::common::mtp diff --git a/dflash/src/qwen35/qwen35_mtp.h b/dflash/src/qwen35/qwen35_mtp.h new file mode 100644 index 00000000..0456ba25 --- /dev/null +++ b/dflash/src/qwen35/qwen35_mtp.h @@ -0,0 +1,210 @@ +// qwen35_mtp.h — Native-heads MTP module for unsloth Qwen3.6 GGUFs. +// +// Qwen3.6's GGUF architecture is `qwen35` (same backbone). The MTP heads are +// stored as additional per-layer tensors on the last `qwen35.nextn_predict_layers` +// blocks of the GGUF, following DeepSeek-V3 / NextN conventions: +// +// blk.{bid}.nextn.eh_proj : [2*n_embd, n_embd] — concat[embed;hidden] -> embed +// blk.{bid}.nextn.enorm : [n_embd] — embed-side norm +// blk.{bid}.nextn.hnorm : [n_embd] — hidden-side norm +// blk.{bid}.nextn.embed_tokens : [n_embd, n_vocab] — optional; shared with backbone if absent +// blk.{bid}.nextn.shared_head_head : [n_embd, n_vocab] — optional; shared with backbone if absent +// blk.{bid}.nextn.shared_head_norm : [n_embd] — optional; shared with backbone if absent +// +// Each MTP head also carries its own transformer block tensors (attn_q/k/v/o, +// ffn_*, ssm_*) at the same layer index — these were already loaded for the +// backbone forward (qwen35 path). The MTP forward reuses them via the +// supplied Qwen35Backend / DFlashTarget. +// +// Per unsloth: recommended γ ≤ 2 (accept rate drops from ~83% at γ=1 to +// ~50% at γ=4). `max_gamma()` honors `nextn_predict_layers` from GGUF. +// +// This module implements INativeMtp from mtp_interface.h: `step_batch` emits +// up to num_heads() draft tokens per call. + +#pragma once + +#include "common/mtp_interface.h" + +#include +#include +#include +#include +#include + +struct ggml_tensor; +struct ggml_context; + +namespace dflash::common { + +struct DFlashTarget; + +namespace mtp { + +// One Qwen3.6 MTP head's weights. There are `n_heads` such entries; head i +// corresponds to GGUF block index `(n_layer - n_heads + i)`. +// +// Shape B (DeepSeek-V3 NextN): each head owns a full transformer-block at its +// GGUF block index. These tensors are loaded from blk.{layer_idx}.* (no nextn. +// prefix). See qwen35_mtp_redesign.md for the verified GGUF tensor inventory. +struct Qwen35MtpHeadWeights { + int layer_idx = -1; // absolute GGUF block index + // NextN-specific tensors (required) + ggml_tensor * eh_proj = nullptr; // [n_embd, 2*n_embd] + ggml_tensor * enorm = nullptr; // [n_embd] + ggml_tensor * hnorm = nullptr; // [n_embd] + ggml_tensor * embed_tokens = nullptr; // optional (nullable -> use backbone) + ggml_tensor * shared_head_head = nullptr; // optional (nullable -> use backbone) + ggml_tensor * shared_head_norm = nullptr; // optional (nullable -> use backbone) + // Head-owned transformer block (Shape B — required, not shared with backbone) + ggml_tensor * attn_norm = nullptr; // [n_embd] + ggml_tensor * attn_q = nullptr; // [n_embd, head_count * key_length] + ggml_tensor * attn_q_norm = nullptr; // [key_length] + ggml_tensor * attn_k = nullptr; // [n_embd, head_count_kv * key_length] + ggml_tensor * attn_k_norm = nullptr; // [key_length] + ggml_tensor * attn_v = nullptr; // [n_embd, head_count_kv * value_length] + ggml_tensor * attn_output = nullptr; // [head_count * value_length, n_embd] + ggml_tensor * post_attention_norm = nullptr; // [n_embd] + ggml_tensor * ffn_gate = nullptr; // [n_embd, ffn_length] + ggml_tensor * ffn_up = nullptr; // [n_embd, ffn_length] + ggml_tensor * ffn_down = nullptr; // [ffn_length, n_embd] +}; + +struct Qwen35MtpWeights { + int n_embd = 0; + int n_vocab = 0; + int n_heads = 0; // == nextn_predict_layers + int n_backbone_layers = 0; // total backbone n_layer + std::vector heads; // size == n_heads + // GGUF metadata copies for cross-checks + std::string backbone_arch; // e.g. "qwen35" + std::string base_model_name; // e.g. "Qwen3.6-27B" + // Attention sizing — read from GGUF; needed to size per-head KV buffers. + int n_head_count = 0; // qwen35.attention.head_count + int n_head_kv = 0; // qwen35.attention.head_count_kv + int n_key_length = 0; // qwen35.attention.key_length + int n_value_length = 0; // qwen35.attention.value_length + int n_ffn_length = 0; // qwen35.feed_forward_length +}; + +// Load Qwen3.6 MTP weights from a GGUF file. Returns false if the file does +// not contain NextN tensors (i.e. it is a non-MTP GGUF). The `ctx` parameter +// owns the loaded tensors' lifetime; pass the same ctx used for backbone. +// +// `expected_n_embd` / `expected_n_vocab` are sanity-checked against GGUF +// metadata; pass values from the bound target. +bool load_qwen35_mtp_weights(const std::string & gguf_path, + ggml_context * ctx, + int expected_n_embd, + int expected_n_vocab, + Qwen35MtpWeights & out_weights, + std::string & out_error); + +// Concrete INativeMtp impl for Qwen3.6 (unsloth -MTP-GGUF). Wraps the +// loaded weights + a bound DFlashTarget (typically Qwen35Backend's +// dflash_target()) which provides backbone embedding, KV, LM head. +class Qwen35MtpModule : public INativeMtp { +public: + Qwen35MtpModule(); + ~Qwen35MtpModule() override; + + // One-shot construction: load weights then attach in a single call. + // The returned module is ready for step_batch() once attach() succeeds. + bool init(const std::string & gguf_path, + DFlashTarget * target, + std::string & out_error); + + // Integration construction: bind NextN tensors from the backbone GGUF + // context. The context lifetime stays with the caller; this module owns + // only the CPU buffer used to materialize the NextN tensors it reads. + bool init(const std::string & gguf_path, + ggml_context * ctx, + DFlashTarget * target, + std::string & out_error); + + // ── IMtpModule ────────────────────────────────────────────────── + int max_gamma() const override; + int effective_gamma() const override { return effective_gamma_; } + void set_effective_gamma(int gamma) override { + effective_gamma_ = (gamma > 0) ? std::min(gamma, max_gamma()) : max_gamma(); + } + int hidden_size() const override; + bool attach(DFlashTarget * target) override; + void reset_chain() override; + void shutdown() override; + + // ── INativeMtp ───────────────────────────────────────────────── + int num_heads() const override; + bool step_batch(int32_t current_token, + int base_pos, + std::vector & out) override; + void set_draft_topk(int k) override; + + // Autoregressive chain draft: runs the head `chain_depth` times, + // feeding the previous iteration's post-shared_head_norm hidden as + // h_prev for the next. Per-iter step graphs are cached in + // state_->step_sg_cache. On the CPU stub path + // (no backend) this falls back to the default `step_batch`+clamp. + bool step_chain(int32_t current_token, + int base_pos, + int chain_depth, + std::vector & out) override; + + // Receive the backbone's final post-norm hidden state for the last committed + // token. Called by the chain runner before each step_batch(). The pointer + // and dim must remain valid for the duration of step_batch(); the module + // does NOT copy or own this buffer in PR 2c-bis (deferred to 2d-bis when + // the forward actually uses it). + void set_initial_hidden(const float * h_prev, int dim) override; + + // Pre-warm the head's K/V cache by running the head's K/V projections on + // every prefill position. After this call, head_kv[0] contains valid K/V + // entries at slots [0, n_prompt-1]; step_batch's range attention can then + // attend to the full prompt context instead of seeing a single slot. + // + // Arguments: + // prompt_tokens : the prompt sequence, length n_prompt. + // n_prompt : number of prompt tokens. + // prefill_next : the backbone's argmax at the end of prefill — used as + // t_{n_prompt} for slot n_prompt-1 (per DeepSeek-V3 + // Eq 21 the head at index i reads Emb(t_{i+k})). + // hiddens : the backbone's per-position post-norm hidden states, + // laid out as [token_0_hidden, ..., token_{n_prompt-1}_hidden]. + // Each token's hidden is hidden_size() floats, F32. + // Returns false on dimension mismatch or tensor-dequant failure. + bool warm_head_kv(const int32_t * prompt_tokens, + int n_prompt, + int32_t prefill_next, + const float * hiddens) override; + + // Read-only access for tests / introspection. + const Qwen35MtpWeights & weights() const; + + // Inject pre-built weights for unit tests without a real GGUF file. + // The "_for_test" / "test_" prefix is the contract: production code uses init(). + void attach_weights_for_test(const Qwen35MtpWeights & w); + + // Test-only accessors: return the last set_initial_hidden state. + // Used by test_qwen35_mtp_step_unit; not for production paths. + const float * test_initial_hidden_ptr() const; + int test_initial_hidden_dim() const; + +private: + struct State; + std::unique_ptr state_; + int effective_gamma_ = 0; // 0 until set_effective_gamma; orchestrator MUST set before decode + + // GPU forward path (cgraph on backbone backend); falls back to the CPU + // path inside step_batch() when no CUDA backend is available. + bool step_batch_gpu_(int32_t current_token, + int base_pos, + std::vector & out); + + // Bug #5 fix: graphs are shape-only, keyed on (head_idx, fa_window, + // fused_lm_head, topk_k). Per-call slot routing (write idx, read + // idxs, mask) is pushed as runtime tensor inputs by the caller. + struct Qwen35MtpStepGraph * get_or_build_step_graph_(int head_idx); +}; + +} // namespace mtp +} // namespace dflash::common diff --git a/dflash/src/qwen35/qwen35_mtp_graph.cpp b/dflash/src/qwen35/qwen35_mtp_graph.cpp new file mode 100644 index 00000000..56a5ba69 --- /dev/null +++ b/dflash/src/qwen35/qwen35_mtp_graph.cpp @@ -0,0 +1,404 @@ +// qwen35_mtp_graph.cpp — CUDA cgraph builder for Qwen3.6 MTP head's +// per-step forward. Mirrors the backbone's full-attention TRMBlock shape +// (qwen35_target_graph.cpp:build_full_attn_block) using ggml ops on the +// backbone's CUDA backend, so the head's matmuls use ggml's quant-aware +// MMQ kernels (matching backbone precision) instead of a CPU fp32 forward +// that drifts on Q2_K / Q3_K weights. + +#include "qwen35_mtp_graph.h" + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" + +#include +#include + +namespace dflash::common { +namespace mtp { + +static ggml_tensor * rms_norm_mul(ggml_context * ctx, ggml_tensor * x, + ggml_tensor * weight, float eps) { + ggml_tensor * n = ggml_rms_norm(ctx, x, eps); + return ggml_mul(ctx, n, weight); +} + +void qwen35_mtp_step_graph_free(Qwen35MtpStepGraph & sg) { + if (sg.alloc) { + ggml_gallocr_free(sg.alloc); + sg.alloc = nullptr; + } + if (sg.ctx) { + ggml_free(sg.ctx); + sg.ctx = nullptr; + } + sg.gf = nullptr; + sg.inp_embed = nullptr; + sg.inp_h_prev = nullptr; + sg.inp_pos = nullptr; + sg.inp_kv_idx_write = nullptr; + sg.inp_kv_idxs_read = nullptr; + sg.inp_kv_mask = nullptr; + sg.out_x_normed = nullptr; + sg.out_h_pre_norm = nullptr; + sg.out_argmax_token = nullptr; + sg.out_logits = nullptr; + sg.fa_window = 0; + sg.fa_max = 0; + sg.topk_k = 0; + sg.fused_lm_head = false; +} + +void qwen35_mtp_warm_graph_free(Qwen35MtpWarmGraph & sg) { + if (sg.alloc) { + ggml_gallocr_free(sg.alloc); + sg.alloc = nullptr; + } + if (sg.ctx) { + ggml_free(sg.ctx); + sg.ctx = nullptr; + } + sg.gf = nullptr; + sg.inp_embed_seq = nullptr; + sg.inp_h_seq = nullptr; + sg.inp_pos = nullptr; +} + +bool build_qwen35_mtp_warm_graph( + Qwen35MtpWarmGraph & sg, + const Qwen35MtpHeadWeights & head, + ggml_tensor * head_k_cache, + ggml_tensor * head_v_cache, + ggml_backend_t backend, + int n_embd, + int n_head_kv, + int key_len, + int val_len, + int n_rot, + int rope_sections[4], + float rope_freq_base, + float rms_eps, + int slot_start, + int n_tokens) { + qwen35_mtp_warm_graph_free(sg); + if (n_tokens <= 0) return false; + + const int head_dim = key_len; // qwen3.6 has key_length == value_length + + ggml_init_params ip{}; + // n_tokens up to a few hundred prompt tokens; concat + matmul of size + // (kv_dim, n_embd) needs scratch. 256 MB headroom is plenty. + ip.mem_size = 256 * 1024 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + sg.ctx = ggml_init(ip); + if (!sg.ctx) return false; + sg.gf = ggml_new_graph_custom(sg.ctx, 4096, false); + + // Inputs. + sg.inp_embed_seq = ggml_new_tensor_2d(sg.ctx, GGML_TYPE_F32, n_embd, n_tokens); + ggml_set_name(sg.inp_embed_seq, "mtp_warm_embed_seq"); + ggml_set_input(sg.inp_embed_seq); + + sg.inp_h_seq = ggml_new_tensor_2d(sg.ctx, GGML_TYPE_F32, n_embd, n_tokens); + ggml_set_name(sg.inp_h_seq, "mtp_warm_h_seq"); + ggml_set_input(sg.inp_h_seq); + + sg.inp_pos = ggml_new_tensor_1d(sg.ctx, GGML_TYPE_I32, 4 * n_tokens); + ggml_set_name(sg.inp_pos, "mtp_warm_pos"); + ggml_set_input(sg.inp_pos); + + // Eq 21: eh_proj([e_norm; h_norm]) — concat order matches llama.cpp PR #22673. + ggml_tensor * e_in = rms_norm_mul(sg.ctx, sg.inp_embed_seq, head.enorm, rms_eps); + ggml_tensor * h_in = rms_norm_mul(sg.ctx, sg.inp_h_seq, head.hnorm, rms_eps); + ggml_tensor * cat = ggml_concat(sg.ctx, e_in, h_in, /*dim=*/0); + ggml_tensor * x = ggml_mul_mat(sg.ctx, head.eh_proj, cat); // [n_embd, n_tokens] + + // Pre-attn norm. + ggml_tensor * cur = rms_norm_mul(sg.ctx, x, head.attn_norm, rms_eps); + + // K, V projections only (no Q, no attention, no FFN). + ggml_tensor * Kcur = ggml_mul_mat(sg.ctx, head.attn_k, cur); // [kv_dim, n_tokens] + Kcur = ggml_reshape_3d(sg.ctx, Kcur, head_dim, n_head_kv, n_tokens); + Kcur = rms_norm_mul(sg.ctx, Kcur, head.attn_k_norm, rms_eps); + + ggml_tensor * Vcur = ggml_mul_mat(sg.ctx, head.attn_v, cur); + Vcur = ggml_reshape_3d(sg.ctx, Vcur, head_dim, n_head_kv, n_tokens); + + Kcur = ggml_rope_multi(sg.ctx, Kcur, sg.inp_pos, /*freq_factors=*/nullptr, + n_rot, rope_sections, GGML_ROPE_TYPE_MROPE, + /*n_ctx_orig=*/0, rope_freq_base, 1.0f, + 0.0f, 1.0f, 0.0f, 0.0f); + + // Permute to [head_dim, n_tokens, n_head_kv] so cpy maps element-wise + // into the cache view (slot range along dim 1). + ggml_tensor * K_T = ggml_permute(sg.ctx, Kcur, 0, 2, 1, 3); + ggml_tensor * V_T = ggml_permute(sg.ctx, Vcur, 0, 2, 1, 3); + + ggml_tensor * k_dst = ggml_view_3d(sg.ctx, head_k_cache, + head_dim, n_tokens, n_head_kv, + head_k_cache->nb[1], head_k_cache->nb[2], + head_k_cache->nb[1] * slot_start); + ggml_tensor * v_dst = ggml_view_3d(sg.ctx, head_v_cache, + head_dim, n_tokens, n_head_kv, + head_v_cache->nb[1], head_v_cache->nb[2], + head_v_cache->nb[1] * slot_start); + + ggml_build_forward_expand(sg.gf, ggml_cpy(sg.ctx, K_T, k_dst)); + ggml_build_forward_expand(sg.gf, ggml_cpy(sg.ctx, V_T, v_dst)); + + sg.alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!sg.alloc) { + std::fprintf(stderr, "[qwen35_mtp_warm_graph] gallocr_new failed\n"); + return false; + } + if (!ggml_gallocr_alloc_graph(sg.alloc, sg.gf)) { + std::fprintf(stderr, "[qwen35_mtp_warm_graph] alloc_graph failed\n"); + return false; + } + return true; +} + +bool build_qwen35_mtp_step_graph( + Qwen35MtpStepGraph & sg, + const Qwen35MtpHeadWeights & head, + ggml_tensor * head_k_cache, + ggml_tensor * head_v_cache, + ggml_backend_t backend, + int n_embd, + int n_head, + int n_head_kv, + int key_len, + int val_len, + int ffn_len, + int n_rot, + int rope_sections[4], + float rope_freq_base, + float rms_eps, + int n_ctx, + int fa_window, + ggml_tensor * lm_head_weight, + int lm_head_topk) { + qwen35_mtp_step_graph_free(sg); + + const int q_dim = n_head * key_len; + const int head_dim = key_len; // qwen3.6 has key_length == value_length + const int fa_max = (fa_window > 0 && fa_window < n_ctx) ? fa_window : n_ctx; + + ggml_init_params ip{}; + ip.mem_size = 128 * 1024 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + sg.ctx = ggml_init(ip); + if (!sg.ctx) return false; + sg.gf = ggml_new_graph_custom(sg.ctx, 2048, false); + + // ─── Inputs ──────────────────────────────────────────────────── + sg.inp_embed = ggml_new_tensor_1d(sg.ctx, GGML_TYPE_F32, n_embd); + ggml_set_name(sg.inp_embed, "mtp_inp_embed"); + ggml_set_input(sg.inp_embed); + + sg.inp_h_prev = ggml_new_tensor_1d(sg.ctx, GGML_TYPE_F32, n_embd); + ggml_set_name(sg.inp_h_prev, "mtp_inp_h_prev"); + ggml_set_input(sg.inp_h_prev); + + // MROPE expects positions as [n_tokens * 4] i32 (4 axes per token). + sg.inp_pos = ggml_new_tensor_1d(sg.ctx, GGML_TYPE_I32, 4); + ggml_set_name(sg.inp_pos, "mtp_inp_pos"); + ggml_set_input(sg.inp_pos); + + // Runtime KV slot routing (bug #5: avoid baking draft_pos into views). + sg.inp_kv_idx_write = ggml_new_tensor_1d(sg.ctx, GGML_TYPE_I64, 1); + ggml_set_name(sg.inp_kv_idx_write, "mtp_inp_kv_idx_write"); + ggml_set_input(sg.inp_kv_idx_write); + + sg.inp_kv_idxs_read = ggml_new_tensor_2d(sg.ctx, GGML_TYPE_I32, fa_max, n_head_kv); + ggml_set_name(sg.inp_kv_idxs_read, "mtp_inp_kv_idxs_read"); + ggml_set_input(sg.inp_kv_idxs_read); + + sg.inp_kv_mask = ggml_new_tensor_2d(sg.ctx, GGML_TYPE_F16, fa_max, 1); + ggml_set_name(sg.inp_kv_mask, "mtp_inp_kv_mask"); + ggml_set_input(sg.inp_kv_mask); + + // ─── Eq 21: eh_proj([hnorm(h_prev); enorm(embed)]) ──────────── + ggml_tensor * e_in = rms_norm_mul(sg.ctx, sg.inp_embed, head.enorm, rms_eps); + ggml_tensor * h_in = rms_norm_mul(sg.ctx, sg.inp_h_prev, head.hnorm, rms_eps); + // Concat order [e_in; h_in] (embed first, hidden second) matches the + // reference impl in llama.cpp PR #22673 (graph_mtp in src/models/qwen35.cpp: + // `ggml_concat(ctx0, e_norm, h_norm, 0)`). The earlier redesign doc had + // the order backwards; with the wrong order, eh_proj receives swapped + // halves of its trained input and the head produces useless logits. + ggml_tensor * cat = ggml_concat(sg.ctx, e_in, h_in, /*dim=*/0); + ggml_tensor * x = ggml_mul_mat(sg.ctx, head.eh_proj, cat); + // x: [n_embd] + + // ─── TRMBlock: pre-attn norm ───────────────────────────────── + ggml_tensor * cur = rms_norm_mul(sg.ctx, x, head.attn_norm, rms_eps); + // cur: [n_embd] + + // ─── Q/gate projection (packed [Q; gate] per head) ─────────── + // attn_q is [n_embd, 2*q_dim]; output is [2*q_dim]. + ggml_tensor * qg = ggml_mul_mat(sg.ctx, head.attn_q, cur); + // Reshape to [head_dim*2, n_head] to split Q (offset 0) and gate (offset head_dim). + qg = ggml_reshape_2d(sg.ctx, qg, head_dim * 2, n_head); + + // Q half: shape [head_dim, n_head], stride head_dim*2 between heads. + ggml_tensor * Q = ggml_view_2d(sg.ctx, qg, + head_dim, n_head, + ggml_element_size(qg) * head_dim * 2, + /*offset=*/0); + Q = rms_norm_mul(sg.ctx, Q, head.attn_q_norm, rms_eps); + + // gate half: same shape, offset head_dim. + ggml_tensor * gate = ggml_view_2d(sg.ctx, qg, + head_dim, n_head, + ggml_element_size(qg) * head_dim * 2, + ggml_element_size(qg) * head_dim); + gate = ggml_cont_2d(sg.ctx, gate, q_dim, 1); // [q_dim, 1] + + // ─── K / V projections ───────────────────────────────────── + ggml_tensor * Kcur = ggml_mul_mat(sg.ctx, head.attn_k, cur); + Kcur = ggml_reshape_2d(sg.ctx, Kcur, head_dim, n_head_kv); + Kcur = rms_norm_mul(sg.ctx, Kcur, head.attn_k_norm, rms_eps); + + ggml_tensor * Vcur = ggml_mul_mat(sg.ctx, head.attn_v, cur); + Vcur = ggml_reshape_2d(sg.ctx, Vcur, head_dim, n_head_kv); + + // ─── MROPE on Q and K (4-axis positions, sections [11,11,10,0]) ── + // rope_multi expects [n_dims, n_head, n_tokens, 1]. + ggml_tensor * Q3 = ggml_reshape_3d(sg.ctx, Q, head_dim, n_head, 1); + Q3 = ggml_rope_multi(sg.ctx, Q3, sg.inp_pos, /*freq_factors=*/nullptr, + n_rot, rope_sections, GGML_ROPE_TYPE_MROPE, + /*n_ctx_orig=*/0, rope_freq_base, 1.0f, + 0.0f, 1.0f, 0.0f, 0.0f); + ggml_tensor * K3 = ggml_reshape_3d(sg.ctx, Kcur, head_dim, n_head_kv, 1); + K3 = ggml_rope_multi(sg.ctx, K3, sg.inp_pos, nullptr, + n_rot, rope_sections, GGML_ROPE_TYPE_MROPE, + 0, rope_freq_base, 1.0f, + 0.0f, 1.0f, 0.0f, 0.0f); + + // ─── Write Kcur/Vcur at runtime slot (inp_kv_idx_write) ────── + // Bug #5: per-slot views can't be built at graph-build time (offset + // must be static). Use ggml_set_rows on a 3D cache view; b is the + // F32 K/V with shape [head_dim, 1, n_head_kv]; c is i64[1] broadcast + // across the head_kv axis (b->ne[2] % c->ne[1] == 0 satisfies the + // broadcast rule, so all heads write the same slot). + // K3 is post-RoPE [head_dim, n_head_kv, 1]; reshape to [head_dim,1,n_head_kv] + // so set_rows sees ne[1]=1 (==c.ne[0]) and broadcasts the i64[1] index over + // the n_head_kv axis. + ggml_tensor * K_b = ggml_reshape_3d(sg.ctx, K3, head_dim, 1, n_head_kv); + ggml_tensor * V_b = ggml_reshape_3d(sg.ctx, Vcur, head_dim, 1, n_head_kv); + ggml_tensor * k_after = ggml_set_rows(sg.ctx, head_k_cache, K_b, sg.inp_kv_idx_write); + ggml_tensor * v_after = ggml_set_rows(sg.ctx, head_v_cache, V_b, sg.inp_kv_idx_write); + ggml_build_forward_expand(sg.gf, k_after); + ggml_build_forward_expand(sg.gf, v_after); + + // ─── Flash attention over runtime-selected slots ───────────── + // Read fa_max rows per head via ggml_get_rows. Indices live in + // inp_kv_idxs_read [fa_max, n_head_kv]; rows past live kv_len are + // gathered from slot 0 then masked to -INF via inp_kv_mask. Read + // from k_after / v_after so the DAG sees the set_rows write as a + // dependency (set_rows returns view(a) so direct dep chaining works). + ggml_tensor * Qfa = ggml_permute(sg.ctx, Q3, 0, 2, 1, 3); + Qfa = ggml_cont(sg.ctx, Qfa); + + ggml_tensor * Kfa = ggml_get_rows(sg.ctx, k_after, sg.inp_kv_idxs_read); + ggml_tensor * Vfa = ggml_get_rows(sg.ctx, v_after, sg.inp_kv_idxs_read); + + const float kq_scale = 1.0f / std::sqrt((float)head_dim); + ggml_tensor * attn = ggml_flash_attn_ext(sg.ctx, Qfa, Kfa, Vfa, + sg.inp_kv_mask, kq_scale, 0.0f, 0.0f); + // attn: [head_dim, n_head, n_tokens=1] (permuted output of FA) + attn = ggml_reshape_2d(sg.ctx, attn, q_dim, 1); + + // ─── Sigmoid gate ──────────────────────────────────────────── + ggml_tensor * gate_sig = ggml_sigmoid(sg.ctx, gate); + attn = ggml_mul(sg.ctx, attn, gate_sig); + + // ─── Output projection + residual ──────────────────────────── + ggml_tensor * attn_out = ggml_mul_mat(sg.ctx, head.attn_output, attn); + // attn_out: [n_embd]; flatten x and attn_out to same rank. + attn_out = ggml_reshape_1d(sg.ctx, attn_out, n_embd); + x = ggml_add(sg.ctx, x, attn_out); + + // ─── Post-attn norm + SwiGLU FFN ───────────────────────────── + cur = rms_norm_mul(sg.ctx, x, head.post_attention_norm, rms_eps); + ggml_tensor * ffn_g = ggml_mul_mat(sg.ctx, head.ffn_gate, cur); + ffn_g = ggml_silu(sg.ctx, ffn_g); + ggml_tensor * ffn_u = ggml_mul_mat(sg.ctx, head.ffn_up, cur); + ggml_tensor * ffn_gu = ggml_mul(sg.ctx, ffn_g, ffn_u); + ggml_tensor * ffn_out = ggml_mul_mat(sg.ctx, head.ffn_down, ffn_gu); + ffn_out = ggml_reshape_1d(sg.ctx, ffn_out, n_embd); + x = ggml_add(sg.ctx, x, ffn_out); + + // ─── Pre-shared_head_norm hidden (chain-state output) ──────── + // Per llama.cpp PR #22673 (`t_h_pre_norm` in src/models/qwen35.cpp), + // the hidden fed back as h_prev for the NEXT autoregressive step + // must be POST-residual-add but PRE-`shared_head_norm`. Feeding + // back `out_x_normed` (post-norm) double-normalises on the next + // iter's `hnorm` and compounds rejection per depth. See the + // CPU-path reference at qwen35_mtp.cpp:1166 (`last_hidden = x`). + ggml_set_name(x, "mtp_out_h_pre_norm"); + ggml_set_output(x); + sg.out_h_pre_norm = x; + ggml_build_forward_expand(sg.gf, sg.out_h_pre_norm); + + // ─── Shared head norm ──────────────────────────────────────── + ggml_tensor * out = head.shared_head_norm + ? rms_norm_mul(sg.ctx, x, head.shared_head_norm, rms_eps) + : ggml_rms_norm(sg.ctx, x, rms_eps); + ggml_set_name(out, "mtp_out_x_normed"); + ggml_set_output(out); + sg.out_x_normed = out; + + ggml_build_forward_expand(sg.gf, sg.out_x_normed); + + // ─── Fused LM-head projection (optional) ───────────────────── + // When the caller passes lm_head_weight non-null we append the LM + // head matmul + argmax directly to this graph so step_batch_gpu_ can + // avoid the hidden -> CPU -> separate projection-graph round trip + // that dominates per-step latency (see qwen35_dflash_target.cpp: + // project_hidden_to_tokens for the unfused path). + if (lm_head_weight) { + // out is shape [n_embd]; reshape to [n_embd, 1] so mul_mat against + // the [n_embd, n_vocab] weight matches the LM-head projection + // step graph (graph_builders.cpp:251). + ggml_tensor * x_for_lm = ggml_reshape_2d(sg.ctx, out, n_embd, 1); + ggml_tensor * logits = ggml_mul_mat(sg.ctx, lm_head_weight, x_for_lm); + ggml_set_name(logits, "mtp_fused_logits"); + if (lm_head_topk > 0) { + // Surface raw logits so the host can run log-softmax for + // top-K without re-running the matmul. For K=1 we skip this + // and download just the argmax to keep transfer minimal. + ggml_set_output(logits); + sg.out_logits = logits; + ggml_build_forward_expand(sg.gf, logits); + } + ggml_tensor * argmax = ggml_argmax(sg.ctx, logits); + ggml_set_name(argmax, "mtp_fused_argmax"); + ggml_set_output(argmax); + sg.out_argmax_token = argmax; + ggml_build_forward_expand(sg.gf, argmax); + } + + // ─── Allocate ──────────────────────────────────────────────── + sg.alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!sg.alloc) { + std::fprintf(stderr, "[qwen35_mtp_graph] ggml_gallocr_new failed\n"); + return false; + } + if (!ggml_gallocr_alloc_graph(sg.alloc, sg.gf)) { + std::fprintf(stderr, "[qwen35_mtp_graph] ggml_gallocr_alloc_graph failed\n"); + return false; + } + + // Record build keys for cache invalidation in Qwen35MtpModule. + sg.fa_window = fa_window; + sg.fa_max = fa_max; + sg.topk_k = lm_head_topk; + sg.fused_lm_head = (lm_head_weight != nullptr); + return true; +} + +} // namespace mtp +} // namespace dflash::common diff --git a/dflash/src/qwen35/qwen35_mtp_graph.h b/dflash/src/qwen35/qwen35_mtp_graph.h new file mode 100644 index 00000000..8125beef --- /dev/null +++ b/dflash/src/qwen35/qwen35_mtp_graph.h @@ -0,0 +1,133 @@ +// qwen35_mtp_graph.h — CUDA cgraph for Qwen3.6 MTP head step forward. + +#pragma once + +#include "qwen35_mtp.h" + +struct ggml_tensor; +struct ggml_context; +struct ggml_cgraph; +struct ggml_gallocr; +typedef struct ggml_gallocr * ggml_gallocr_t; +struct ggml_backend; +typedef struct ggml_backend * ggml_backend_t; + +namespace dflash::common { +namespace mtp { + +// Step graph state for a single MTP head's per-call forward. Inputs are +// set via ggml_backend_tensor_set; the output (out_x_normed) is read back +// to host (or, when the fused LM-head path is wired, out_argmax_token / +// out_topk_* are read instead and out_x_normed is unused). +// +// Bug #5 fix: graphs are shape-only (no slot/kv_len baked in). The KV +// slot to write and the FA read indices + mask are runtime inputs, so +// one graph per (head, fa_window, fused_lm_head, topk_k) suffices for +// the entire generation instead of one per draft_pos. +struct Qwen35MtpStepGraph { + ggml_context * ctx = nullptr; + ggml_cgraph * gf = nullptr; + ggml_gallocr_t alloc = nullptr; + + // Build-time keys (used to detect invalidation). + int fa_window = 0; + int fa_max = 0; // baked FA window width (rows in inp_kv_idxs_read/inp_kv_mask) + int topk_k = 0; + bool fused_lm_head = false; + + // Inputs. + ggml_tensor * inp_embed = nullptr; // [n_embd] f32 — pre-embedded cur_token + ggml_tensor * inp_h_prev = nullptr; // [n_embd] f32 — backbone hidden h_{base_pos-1} + ggml_tensor * inp_pos = nullptr; // [4] i32 — MROPE positions (p,p,p,0) + ggml_tensor * inp_kv_idx_write= nullptr; // [1] i64 — slot to write Kcur/Vcur + ggml_tensor * inp_kv_idxs_read= nullptr; // [fa_max,n_head_kv] i32 — FA read slots + ggml_tensor * inp_kv_mask = nullptr; // [fa_max,1] f16 — -INF on inactive rows + + // Outputs. + ggml_tensor * out_x_normed = nullptr; // [n_embd] f32 — post shared_head_norm hidden + // Pre-shared_head_norm hidden (post FFN residual `add`). Mirrors + // llama.cpp PR #22673's `t_h_pre_norm`: this is the tensor that must + // be fed back as h_prev for the NEXT autoregressive step. Re-using + // `out_x_normed` here causes the next iter's `hnorm` to double- + // normalise, producing a distribution drift that compounds per depth + // (D=3 accept collapsed from ~91% to ~67% per-position in our bench; + // see qwen35_mtp.cpp:1166 for the byte-correct CPU stash of pre-norm). + ggml_tensor * out_h_pre_norm = nullptr; // [n_embd] f32 — pre shared_head_norm hidden + // Fused LM-head outputs (only populated when build is called with a + // non-null lm_head_weight). out_argmax_token holds the i32 argmax of + // the logits; out_logits is exposed so the host can compute log-softmax + // for top-K emission without re-running the LM head matmul. + ggml_tensor * out_argmax_token = nullptr; // [1] i32 + ggml_tensor * out_logits = nullptr; // [n_vocab] f32 — full logits (optional) +}; + +void qwen35_mtp_step_graph_free(Qwen35MtpStepGraph & sg); + +// Batched warmup graph: writes K/V to head_kv slots [slot_start, slot_start+n_tokens) +// in a single backend pass. Replaces the host-side per-position CPU loop with +// one ggml cgraph using the same quant-aware matmul kernels the step graph uses. +struct Qwen35MtpWarmGraph { + ggml_context * ctx = nullptr; + ggml_cgraph * gf = nullptr; + ggml_gallocr_t alloc = nullptr; + + ggml_tensor * inp_embed_seq = nullptr; // [n_embd, n_tokens] f32 — pre-embedded tokens + ggml_tensor * inp_h_seq = nullptr; // [n_embd, n_tokens] f32 — backbone hiddens + ggml_tensor * inp_pos = nullptr; // [4 * n_tokens] i32 — MROPE positions +}; + +void qwen35_mtp_warm_graph_free(Qwen35MtpWarmGraph & sg); + +// Build the warmup graph for n_tokens prefill positions writing to slots +// [slot_start, slot_start + n_tokens) of head_k_cache / head_v_cache. +bool build_qwen35_mtp_warm_graph( + Qwen35MtpWarmGraph & sg, + const Qwen35MtpHeadWeights & head, + ggml_tensor * head_k_cache, + ggml_tensor * head_v_cache, + ggml_backend_t backend, + int n_embd, + int n_head_kv, + int key_len, + int val_len, + int n_rot, + int rope_sections[4], + float rope_freq_base, + float rms_eps, + int slot_start, + int n_tokens); + +// Build the head's step graph as a SHAPE-ONLY template. Per-call slot +// (write) and FA read indices/mask are wired as runtime inputs so a single +// graph services every draft_pos. +// - head_k_cache / head_v_cache: per-head KV cache tensors on the backbone +// backend; layout [head_dim, n_ctx, n_head_kv]. F16 on device. +// - n_ctx: cache row count (used to reshape the cache for set_rows/get_rows). +// - fa_window: if > 0, FA attends fa_max=min(fa_window,n_ctx) rows. If <=0, +// fa_max=n_ctx (full context). Rows beyond live kv_len are masked at +// runtime via inp_kv_mask. +// - lm_head_weight / lm_head_topk: see prior comment block; behaviour unchanged. +// Returns false on allocation failure. +bool build_qwen35_mtp_step_graph( + Qwen35MtpStepGraph & sg, + const Qwen35MtpHeadWeights & head, + ggml_tensor * head_k_cache, + ggml_tensor * head_v_cache, + ggml_backend_t backend, + int n_embd, + int n_head, + int n_head_kv, + int key_len, + int val_len, + int ffn_len, + int n_rot, + int rope_sections[4], + float rope_freq_base, + float rms_eps, + int n_ctx, + int fa_window = 0, + ggml_tensor * lm_head_weight = nullptr, + int lm_head_topk = 0); + +} // namespace mtp +} // namespace dflash::common diff --git a/dflash/src/qwen35/qwen35_mtp_loader.cpp b/dflash/src/qwen35/qwen35_mtp_loader.cpp new file mode 100644 index 00000000..e6f449e0 --- /dev/null +++ b/dflash/src/qwen35/qwen35_mtp_loader.cpp @@ -0,0 +1,225 @@ +// qwen35_mtp_loader.cpp — Discovery loader for Qwen3.6 -MTP-GGUF files. +// +// Parses the GGUF tensor inventory for a Qwen3.6 MTP assistant model +// (Unsloth-style -MTP-GGUF, one or more MTP heads sharing the backbone's +// embedding + LM-head weights). Resolves donor layers and head dimensions +// from the GGUF metadata, populates MtpDrafterWeights for the runtime +// graph, and returns a ggml_context that owns the head tensors. +// +// The architecture string in general.architecture is 'qwen35' — same as +// the backbone. The MTP heads live at blk..* (one past the last +// backbone block), and discovery is keyed off the nextn_predict_layers +// metadata field. + +#include "qwen35_mtp.h" + +#include "common/gguf_metadata.h" +#include "ggml.h" +#include "gguf.h" + +#include +#include +#include +#include +#include + +namespace dflash::common::mtp { + +namespace { + +// Bind a tensor pointer by exact name lookup on the ggml_context that +// gguf_init_from_file built. Returns nullptr if absent. +ggml_tensor * find_tensor(ggml_context * ctx, const std::string & name) { + return ggml_get_tensor(ctx, name.c_str()); +} + +} // anonymous + +bool load_qwen35_mtp_weights(const std::string & gguf_path, + ggml_context * ctx, + int expected_n_embd, + int expected_n_vocab, + Qwen35MtpWeights & out_weights, + std::string & out_error) { + out_error.clear(); + + struct gguf_init_params gp{}; + gp.no_alloc = true; + gp.ctx = nullptr; + struct gguf_context * gguf = gguf_init_from_file(gguf_path.c_str(), gp); + if (!gguf) { + out_error = "qwen35_mtp_loader: gguf_init_from_file failed for " + gguf_path; + return false; + } + + // ── Arch and dimensions ──────────────────────────────────────────── + { + std::string arch_err; + if (!dflash::common::gguf_check_architecture(gguf, "qwen35", arch_err)) { + out_error = "qwen35_mtp_loader: " + arch_err; + gguf_free(gguf); + return false; + } + } + out_weights.backbone_arch = "qwen35"; + out_weights.base_model_name = dflash::common::gguf_get_str_or(gguf, "general.name", ""); + + uint32_t n_layer = 0; + if (!dflash::common::gguf_require_u32(gguf, "qwen35.block_count", n_layer, out_error)) { + out_error = "qwen35_mtp_loader: " + out_error; + gguf_free(gguf); + return false; + } + uint32_t n_embd_meta = 0; + if (!dflash::common::gguf_require_u32(gguf, "qwen35.embedding_length", n_embd_meta, out_error)) { + out_error = "qwen35_mtp_loader: " + out_error; + gguf_free(gguf); + return false; + } + if ((int)n_embd_meta != expected_n_embd) { + char msg[256]; + std::snprintf(msg, sizeof msg, + "qwen35_mtp_loader: backbone n_embd mismatch (gguf=%u, expected=%d)", + n_embd_meta, expected_n_embd); + out_error = msg; + gguf_free(gguf); + return false; + } + (void)expected_n_vocab; // vocab cross-check folded into the optional + // embed_tokens / shared_head_head shape check + + // ── NextN head count ────────────────────────────────────────────── + uint32_t n_heads = 0; + if (!dflash::common::gguf_require_u32(gguf, "qwen35.nextn_predict_layers", n_heads, out_error)) { + out_error = "qwen35_mtp_loader: GGUF lacks qwen35.nextn_predict_layers — this is not an MTP variant"; + gguf_free(gguf); + return false; + } + if (n_heads == 0) { + out_error = "qwen35_mtp_loader: nextn_predict_layers=0 — no MTP heads in this file"; + gguf_free(gguf); + return false; + } + if (n_heads > (uint32_t)n_layer) { + char msg[256]; + std::snprintf(msg, sizeof msg, + "qwen35_mtp_loader: nextn_predict_layers=%u > n_layer=%u (corrupt metadata)", + n_heads, n_layer); + out_error = msg; + gguf_free(gguf); + return false; + } + + // ── Attention sizing metadata ───────────────────────────────────────── + uint32_t n_head_count = 0; + if (!dflash::common::gguf_require_u32(gguf, "qwen35.attention.head_count", n_head_count, out_error)) { + out_error = "qwen35_mtp_loader: " + out_error; + gguf_free(gguf); + return false; + } + uint32_t n_head_kv = 0; + if (!dflash::common::gguf_require_u32(gguf, "qwen35.attention.head_count_kv", n_head_kv, out_error)) { + out_error = "qwen35_mtp_loader: " + out_error; + gguf_free(gguf); + return false; + } + uint32_t n_key_length = 0; + if (!dflash::common::gguf_require_u32(gguf, "qwen35.attention.key_length", n_key_length, out_error)) { + out_error = "qwen35_mtp_loader: " + out_error; + gguf_free(gguf); + return false; + } + uint32_t n_value_length = 0; + if (!dflash::common::gguf_require_u32(gguf, "qwen35.attention.value_length", n_value_length, out_error)) { + out_error = "qwen35_mtp_loader: " + out_error; + gguf_free(gguf); + return false; + } + uint32_t n_ffn_length = 0; + if (!dflash::common::gguf_require_u32(gguf, "qwen35.feed_forward_length", n_ffn_length, out_error)) { + out_error = "qwen35_mtp_loader: " + out_error; + gguf_free(gguf); + return false; + } + + out_weights.n_embd = (int)n_embd_meta; + out_weights.n_vocab = expected_n_vocab; + out_weights.n_heads = (int)n_heads; + out_weights.n_backbone_layers = (int)n_layer; + out_weights.n_head_count = (int)n_head_count; + out_weights.n_head_kv = (int)n_head_kv; + out_weights.n_key_length = (int)n_key_length; + out_weights.n_value_length = (int)n_value_length; + out_weights.n_ffn_length = (int)n_ffn_length; + out_weights.heads.clear(); + out_weights.heads.resize(n_heads); + + // ── Bind per-head tensor pointers ───────────────────────────────── + // Heads occupy the last `n_heads` block indices: [n_layer - n_heads, + // n_layer - 1]. For each, the required tensors are eh_proj/enorm/ + // hnorm; embed_tokens / shared_head_head / shared_head_norm are + // optional (caller falls back to backbone tensors when absent). + int missing_required = 0; + for (int h = 0; h < (int)n_heads; h++) { + const int layer_idx = (int)n_layer - (int)n_heads + h; + auto & head = out_weights.heads[h]; + head.layer_idx = layer_idx; + + auto bind = [&](const char * base, bool required) -> ggml_tensor * { + char name[256]; + std::snprintf(name, sizeof name, "blk.%d.%s.weight", layer_idx, base); + ggml_tensor * t = find_tensor(ctx, name); + if (!t && required) { + std::fprintf(stderr, + "[qwen35_mtp_loader] missing required tensor: %s\n", name); + missing_required++; + } + return t; + }; + head.eh_proj = bind("nextn.eh_proj", /*required=*/true); + head.enorm = bind("nextn.enorm", /*required=*/true); + head.hnorm = bind("nextn.hnorm", /*required=*/true); + head.embed_tokens = bind("nextn.embed_tokens", /*required=*/false); + head.shared_head_head = bind("nextn.shared_head_head", /*required=*/false); + head.shared_head_norm = bind("nextn.shared_head_norm", /*required=*/false); + + // Shape B: head-owned transformer-block tensors (required for all heads). + // These live at blk.{layer_idx}.{name}.weight (no nextn. prefix). + auto bind_blk = [&](const char * name, bool required) -> ggml_tensor * { + char full[256]; + std::snprintf(full, sizeof full, "blk.%d.%s.weight", layer_idx, name); + ggml_tensor * t = find_tensor(ctx, full); + if (!t && required) { + std::fprintf(stderr, + "[qwen35_mtp_loader] missing required tensor: %s\n", full); + missing_required++; + } + return t; + }; + head.attn_norm = bind_blk("attn_norm", /*required=*/true); + head.attn_q = bind_blk("attn_q", /*required=*/true); + head.attn_q_norm = bind_blk("attn_q_norm", /*required=*/true); + head.attn_k = bind_blk("attn_k", /*required=*/true); + head.attn_k_norm = bind_blk("attn_k_norm", /*required=*/true); + head.attn_v = bind_blk("attn_v", /*required=*/true); + head.attn_output = bind_blk("attn_output", /*required=*/true); + head.post_attention_norm = bind_blk("post_attention_norm", /*required=*/true); + head.ffn_gate = bind_blk("ffn_gate", /*required=*/true); + head.ffn_up = bind_blk("ffn_up", /*required=*/true); + head.ffn_down = bind_blk("ffn_down", /*required=*/true); + } + + gguf_free(gguf); + + if (missing_required > 0) { + char msg[256]; + std::snprintf(msg, sizeof msg, + "qwen35_mtp_loader: %d required NextN tensor(s) missing — context likely lacks the MTP tensors. Did the backbone loader allocate them?", + missing_required); + out_error = msg; + return false; + } + return true; +} + +} // namespace dflash::common::mtp diff --git a/dflash/src/qwen35/qwen35_target_graph.cpp b/dflash/src/qwen35/qwen35_target_graph.cpp index fdb3a914..09eda946 100644 --- a/dflash/src/qwen35/qwen35_target_graph.cpp +++ b/dflash/src/qwen35/qwen35_target_graph.cpp @@ -880,8 +880,27 @@ static ggml_tensor * build_delta_net_block( S_v * S_v * r_elt, S_v * S_v * H_v * r_elt, inter_offset); + // The persistent cache buffer is sized for max_verify_tokens slots + // along its last dim (e.g. 16); chain verify may feed fewer than that + // (e.g. g+1 tokens, where g = max_gamma). ggml_cpy requires matching + // nelements, so slice the dst down to n_seq_tokens. Caller (gate in + // qwen35_dflash_target.cpp::verify_batch) ensures n_seq_tokens <= + // cap->ssm_intermediate_states->ne[3] before enabling capture, so the + // view never overflows the destination buffer. + ggml_tensor * inter_dst = cap->ssm_intermediate_states; + if ((int64_t)n_seq_tokens != inter_dst->ne[3]) { + inter_dst = ggml_view_4d(ctx, cap->ssm_intermediate_states, + cap->ssm_intermediate_states->ne[0], + cap->ssm_intermediate_states->ne[1], + cap->ssm_intermediate_states->ne[2], + n_seq_tokens, + cap->ssm_intermediate_states->nb[1], + cap->ssm_intermediate_states->nb[2], + cap->ssm_intermediate_states->nb[3], + /*offset=*/0); + } ggml_build_forward_expand(gf, - ggml_cpy(ctx, inter_view, cap->ssm_intermediate_states)); + ggml_cpy(ctx, inter_view, inter_dst)); } } // end of block started at `{` before `const int64_t S_v = head_v_dim;` @@ -1142,8 +1161,60 @@ QwenGraphOutputs build_qwen35_graph( } // 2. Final norm + // + // 2pre. Capture PRE-final-output-norm hidden (mirrors llama.cpp + // PR #22673's `t_h_pre_norm` from src/models/qwen35.cpp:208-211). + // The Qwen3.6 MTP head's hnorm normalises h_prev internally; if + // we feed it the post-output-norm hidden it double-normalises, + // compounding the per-depth rejection rate (see audit notes in + // qwen35_mtp.cpp:1743 / qwen35_mtp_graph.cpp:329). Only wired + // as a graph output when capture_all_norm_hidden is set — that + // flag is owned by the MTP module's adapter. + ggml_tensor * last_h_pre_norm = nullptr; + ggml_tensor * all_h_pre_norm = nullptr; + if (in.capture_all_norm_hidden) { + ggml_tensor * inpL_2d = ggml_reshape_2d(ctx, inpL, hidden, n_tokens); + last_h_pre_norm = ggml_view_2d(ctx, inpL_2d, hidden, 1, + inpL_2d->nb[1], + (size_t)(n_tokens - 1) * inpL_2d->nb[1]); + ggml_set_name(last_h_pre_norm, "last_h_pre_norm"); + ggml_set_output(last_h_pre_norm); + ggml_build_forward_expand(gf, last_h_pre_norm); + + all_h_pre_norm = inpL_2d; + ggml_set_name(all_h_pre_norm, "all_h_pre_norm"); + ggml_set_output(all_h_pre_norm); + ggml_build_forward_expand(gf, all_h_pre_norm); + } + ggml_tensor * out = rms_norm_mul(ctx, inpL, w.out_norm, w.rms_eps); + // 2a. Expose the last token's post-norm hidden as a named graph output. + // This is h_prev_0 for the Qwen3.6 MTP module (the backbone's final + // hidden state for the last committed token). We always view the last + // column of `out` regardless of whether last_token_logits_only is set, + // so the MTP module always receives exactly one [n_embd] slice. + // Empirical check: pre-norm capture (inpL) makes draft accept rate + // strictly worse on Qwen3.6-27B; post-norm is correct. + ggml_tensor * last_norm_hidden = ggml_view_2d(ctx, out, hidden, 1, + out->nb[1], + (size_t)(n_tokens - 1) * out->nb[1]); + ggml_set_name(last_norm_hidden, "last_norm_hidden"); + ggml_set_output(last_norm_hidden); + ggml_build_forward_expand(gf, last_norm_hidden); + + // Optionally expose the full [n_embd, n_tokens] post-norm hidden sequence + // (used by Qwen3.6 MTP warm_head_kv). Default off: keeping this output + // pinned across compute would otherwise reserve ~7.5MB at ubatch=384 for + // the non-MTP target_gen path that throws the sequence away. + ggml_tensor * all_norm_hidden = nullptr; + if (in.capture_all_norm_hidden) { + all_norm_hidden = out; + ggml_set_name(all_norm_hidden, "all_norm_hidden"); + ggml_set_output(all_norm_hidden); + ggml_build_forward_expand(gf, all_norm_hidden); + } + // 3. LM head — optionally only for the last token (prefill optimization: // reduces logits from [vocab, n_tokens] to [vocab, 1], saving ~233MB // scratch at ubatch=384 and eliminating a large matmul). @@ -1163,6 +1234,10 @@ QwenGraphOutputs build_qwen35_graph( QwenGraphOutputs og = std::move(og_early); og.logits = logits; + og.last_norm_hidden = last_norm_hidden; + og.all_norm_hidden = all_norm_hidden; + og.last_h_pre_norm = last_h_pre_norm; + og.all_h_pre_norm = all_h_pre_norm; return og; } diff --git a/dflash/src/server/server_main.cpp b/dflash/src/server/server_main.cpp index 319f97de..cc7a22db 100644 --- a/dflash/src/server/server_main.cpp +++ b/dflash/src/server/server_main.cpp @@ -69,6 +69,17 @@ static void print_usage(const char * prog) { " --prefill-drafter Drafter GGUF for compression (Qwen3-0.6B)\n" " --prefill-skip-park Skip park/unpark (for >=32GB GPUs)\n" "\n" + "MTP speculative decoding (mutually exclusive with --draft):\n" + " --mtp-source \n" + " MTP source (default: auto when --mtp-gamma given)\n" + " none = disable MTP\n" + " native = MTP heads in target GGUF (unsloth single-file)\n" + " external = separate GGUF via --mtp-gguf\n" + " auto = probe target GGUF; native if found, else none\n" + " --mtp-gguf MTP GGUF path (required only for --mtp-source external)\n" + " --mtp-gamma Speculation chain depth (default: 0 = disabled)\n" + " --mtp-draft-topk Top-k draft strategy (default: chain; >1 enables mtp_topk)\n" + "\n" "Disk KV cache:\n" " --kv-cache-dir Directory for ondisk KV cache (enables feature)\n" " --kv-cache-budget Max disk usage in MB (default: 4096)\n" @@ -140,6 +151,34 @@ int main(int argc, char ** argv) { sconfig.pflash_drafter_path = argv[++i]; } else if (std::strcmp(argv[i], "--prefill-skip-park") == 0) { sconfig.pflash_skip_park = true; + } else if (std::strcmp(argv[i], "--mtp-source") == 0 && i + 1 < argc) { + const char * src = argv[++i]; + if (std::strcmp(src, "none") == 0) + bargs.mtp_source = MtpSource::None; + else if (std::strcmp(src, "native") == 0) + bargs.mtp_source = MtpSource::Native; + else if (std::strcmp(src, "external") == 0) + bargs.mtp_source = MtpSource::ExternalDrafter; + else if (std::strcmp(src, "auto") == 0) + bargs.mtp_source = MtpSource::Auto; + else { + std::fprintf(stderr, "[server] unknown --mtp-source: '%s' (expected: none|native|external|auto)\n", src); + print_usage(argv[0]); + return 1; + } + } else if (std::strcmp(argv[i], "--mtp-gguf") == 0 && i + 1 < argc) { + bargs.mtp_gguf_path = argv[++i]; + } else if (std::strcmp(argv[i], "--mtp-gamma") == 0 && i + 1 < argc) { + bargs.mtp_gamma = std::atoi(argv[++i]); + } else if (std::strcmp(argv[i], "--mtp-draft-source") == 0 && i + 1 < argc) { + ++i; // consume the argument + std::fprintf(stderr, + "[server] WARNING: --mtp-draft-source is deprecated. " + "Use --mtp-source [none|native|external|auto] and " + "--mtp-draft-topk instead.\n"); + } else if (std::strcmp(argv[i], "--mtp-draft-topk") == 0 && i + 1 < argc) { + bargs.mtp_draft_topk = std::atoi(argv[++i]); + if (bargs.mtp_draft_topk > 1) bargs.mtp_use_topk = true; } else if (std::strcmp(argv[i], "--kv-cache-dir") == 0 && i + 1 < argc) { sconfig.disk_cache_dir = argv[++i]; } else if (std::strcmp(argv[i], "--kv-cache-budget") == 0 && i + 1 < argc) { @@ -168,6 +207,42 @@ int main(int argc, char ** argv) { sconfig.max_ctx = bargs.device.max_ctx; } + // Infer MtpSource from legacy flags when --mtp-source is absent. + // --mtp-gguf without --mtp-source → ExternalDrafter (backward compat) + // --mtp-gamma without --mtp-source → Auto (probe the target GGUF) + // Only infer when Unset — explicit --mtp-source none must not be overridden. + if (bargs.mtp_source == MtpSource::Unset) { + if (bargs.mtp_gguf_path) { + bargs.mtp_source = MtpSource::ExternalDrafter; + } else if (bargs.mtp_gamma > 0) { + bargs.mtp_source = MtpSource::Auto; + } else { + bargs.mtp_source = MtpSource::None; + } + } + + // Validate: ExternalDrafter requires --mtp-gguf. + if (bargs.mtp_source == MtpSource::ExternalDrafter && !bargs.mtp_gguf_path) { + std::fprintf(stderr, + "[server] ERROR: --mtp-source external requires --mtp-gguf \n"); + return 1; + } + + // --draft and MTP are mutually exclusive; MTP wins if both are set. + // --draft suppression and env defaults gate on a concrete MTP source. + // Auto and Unset defer to backend resolution; this preserves --draft + // as a fallback when Auto resolves to None and avoids reserving MTP + // head_kv memory that may never be used. + const bool mtp_active = + (bargs.mtp_source == MtpSource::Native || + bargs.mtp_source == MtpSource::ExternalDrafter); + if (bargs.draft_path && mtp_active) { + std::fprintf(stderr, + "[server] WARNING: --draft and MTP both set; ignoring --draft.\n" + "[server] MTP speculation takes precedence over DFlash draft.\n"); + bargs.draft_path = nullptr; + } + // ── Apply environment defaults (mirrors server.py logic) ──────────── // Explicit --cache-type-k/v override via env vars. if (!cache_type_k.empty()) { @@ -186,6 +261,15 @@ int main(int argc, char ** argv) { } #endif + // Default MTP head_kv capacity to backbone max_ctx so prompts up to max_ctx + // never overflow the head_kv buffer (the old hardcoded 8192 caused a silent + // server crash when agentic prompts exceeded that length). + if (mtp_active && sconfig.max_ctx > 0) { + char ctx_str[32]; + std::snprintf(ctx_str, sizeof(ctx_str), "%d", sconfig.max_ctx); + setenv("DFLASH27B_MTP_CTX", ctx_str, 0); // don't overwrite user env + } + // PFlash performance defaults: BSA kernel + sparse alpha + full attention window. bool pflash_enabled = (sconfig.pflash_mode != ServerConfig::PflashMode::OFF); if (pflash_enabled) { @@ -245,6 +329,18 @@ int main(int argc, char ** argv) { std::fprintf(stderr, "[server] │ fa_window = %d\n", bargs.fa_window); std::fprintf(stderr, "[server] │ ddtree = %s\n", bargs.ddtree_mode ? "ON" : "off"); std::fprintf(stderr, "[server] │ ddtree_budget = %d\n", bargs.ddtree_budget); + if (mtp_active) { + const char * src_str = + bargs.mtp_source == MtpSource::Native ? "native" : + bargs.mtp_source == MtpSource::ExternalDrafter ? "external" : + bargs.mtp_source == MtpSource::Auto ? "auto" : "none"; + std::fprintf(stderr, "[server] │ mtp_source = %s\n", src_str); + if (bargs.mtp_gguf_path) + std::fprintf(stderr, "[server] │ mtp_gguf = %s\n", bargs.mtp_gguf_path); + std::fprintf(stderr, "[server] │ mtp_gamma = %d\n", bargs.mtp_gamma); + std::fprintf(stderr, "[server] │ mtp_draft_strat = %s\n", + bargs.mtp_use_topk ? "mtp_topk" : "chain (default)"); + } std::fprintf(stderr, "[server] │ cors = %s\n", sconfig.enable_cors ? "ON" : "off"); std::fprintf(stderr, "[server] │ cache_type_k = %s\n", #ifdef GGML_USE_HIP diff --git a/dflash/test/test_common_mtp_orchestrator.cpp b/dflash/test/test_common_mtp_orchestrator.cpp new file mode 100644 index 00000000..4e8aec8d --- /dev/null +++ b/dflash/test/test_common_mtp_orchestrator.cpp @@ -0,0 +1,812 @@ +// Unit test driving the extraction of MTP orchestration from +// qwen35_backend.cpp into dflash/src/common/mtp_orchestrator. +// +// Pins the public surface of the common helper and proves it handles the +// trivial guard cases — null backend, no MTP support — before any real +// backend is wired through it. +// +// T5-T10: MtpChainRunner state machine (gamma propagation, EOS, partial +// accept, n_gen termination, step failure, stats accounting). +// T11-T14: MtpOrchestrator lifecycle (reset_chain ordering, set_initial_hidden +// plumbing, gamma derivation, warm_head_kv gate). +// T15-T19: Qwen35MtpModule error paths (no GGUF required). +// +// Plain int main(), assert-based, mirrors test_kv_quant.cpp style. + +#include "common/mtp_orchestrator.h" +#include "common/mtp_chain_runner.h" +#include "common/model_backend.h" +#include "common/mtp_interface.h" +#include "common/dflash_target.h" +#include "qwen35/qwen35_mtp.h" + +#include +#include +#include +#include +#include + +namespace { + +// ── Base stubs (used by T1-T4 and reused by new tests) ─────────────────── + +// Minimal stub backend: implements every pure-virtual but every operation +// is a no-op or returns false. Lets us exercise the orchestrator's guard +// paths without loading any GPU weights. +struct StubBackend : public dflash::common::ModelBackend { + bool supports_mtp_value = false; + + void print_ready_banner() const override {} + bool park(const std::string &) override { return true; } + bool unpark(const std::string &) override { return true; } + bool is_target_parked() const override { return false; } + dflash::common::GenerateResult generate(const dflash::common::GenerateRequest &, + const dflash::common::DaemonIO &) override { + return {}; + } + bool snapshot_save(int) override { return false; } + void snapshot_free(int) override {} + bool snapshot_used(int) const override { return false; } + int snapshot_cur_pos(int) const override { return 0; } + dflash::common::GenerateResult restore_and_generate(int, + const dflash::common::GenerateRequest &, + const dflash::common::DaemonIO &) override { + return {}; + } + bool handle_compress(const std::string &, + const dflash::common::DaemonIO &) override { return false; } + void free_drafter() override {} + bool supports_mtp() const override { return supports_mtp_value; } + dflash::common::mtp::IMtpModule * mtp() override { return nullptr; } + void shutdown() override {} +}; + +// Mock NativeHeads MTP module for the "interface is generic" test (T4). +// Records calls to prove the orchestrator invokes the right virtuals. +struct StubMtpModule : public dflash::common::mtp::INativeMtp { + int reset_chain_calls = 0; + int set_initial_hidden_calls = 0; + int max_gamma() const override { return 3; } + int effective_gamma_value = 3; + int effective_gamma() const override { return effective_gamma_value; } + void set_effective_gamma(int g) override { effective_gamma_value = g; } + int hidden_size() const override { return 4; } + bool attach(dflash::common::DFlashTarget *) override { return true; } + void reset_chain() override { reset_chain_calls++; } + void shutdown() override {} + int num_heads() const override { return 3; } + bool step_batch(int32_t, int, + std::vector &) override { return true; } + void set_initial_hidden(const float *, int) override { set_initial_hidden_calls++; } +}; + +// Mock DFlashTarget that fails verify_batch — proves the orchestrator reached +// the prefill loop using only abstract DFlashTarget methods. +struct StubTarget : public dflash::common::DFlashTarget { + int verify_batch_calls = 0; + int hidden_size() const override { return 4; } + bool verify_batch(const std::vector &, int, int &, + std::vector *) override { + verify_batch_calls++; + return false; + } + bool verify_tree(const std::vector &, const dflash::common::DDTree &, + int, std::vector &, std::vector *) override { return false; } + bool snapshot_kv() override { return false; } + bool restore_kv() override { return false; } + bool restore_kv_at_dfs(const std::vector &) override { return false; } + bool restore_kv_at_chain(int) override { return false; } + void capture_topology_for_chain(int, int) override {} + bool is_eos(int) const override { return false; } + bool embed_tokens(const int32_t *, int, float *) const override { return false; } + bool project_hidden_to_tokens(const float *, int, + std::vector &) override { return false; } + bool project_hidden_to_logits(const float *, int, + std::vector &, int &) override { return false; } + int mask_token_id() const override { return 0; } + const std::vector & capture_layer_ids() const override { + static std::vector empty; + return empty; + } + ggml_backend * backend() const override { return nullptr; } + ggml_tensor * lm_head_weight() const override { return nullptr; } + int fa_window() const override { return 0; } +}; + +// Richer stub backend that returns valid mtp() + dflash_target() pointers, +// proving the orchestrator drives a generic ModelBackend without any +// Qwen35-specific cast or include. +struct FullStubBackend : public StubBackend { + StubMtpModule mtp_module; + StubTarget target; + FullStubBackend() { supports_mtp_value = true; } + dflash::common::mtp::IMtpModule * mtp() override { return &mtp_module; } + dflash::common::DFlashTarget * dflash_target() override { return ⌖ } +}; + +// ── Extended stubs for chain runner + orchestrator lifecycle tests ───────── + +// DFlashTarget that succeeds verify_batch and is configurable for chain tests. +// +// verify_batch behavior: +// - returns true, sets last_tok = argmax_token, fills all_argmax if requested +// - all_argmax[i] = candidate[i] for i < accept_n (accept), then diverge_token +// - this models the target accepting accept_n drafts + emitting a bonus token +// +// Also supports: snapshot_kv (succeeds), restore_kv (succeeds), +// restore_kv_at_chain (returns false to force slow rollback path), +// is_eos (returns true for eos_token_id when set). +struct SuccessStubTarget : public dflash::common::DFlashTarget { + int argmax_token = 42; // returned as last_tok and as the bonus token + int accept_n = 0; // how many candidates to "accept" in all_argmax + int eos_token_id = -1; // token for which is_eos returns true + int verify_calls = 0; + int hidden_sz = 4; + // Hidden seq buffer returned by last_hidden_seq (sized to last verify chunk). + mutable std::vector hidden_seq_buf; + mutable int hidden_seq_n = 0; + + int hidden_size() const override { return hidden_sz; } + + bool verify_batch(const std::vector & tokens, + int /*base_pos*/, + int & last_tok, + std::vector * all_argmax) override { + verify_calls++; + last_tok = argmax_token; + if (all_argmax) { + all_argmax->resize(tokens.size()); + for (int i = 0; i < (int)tokens.size(); i++) { + if (i < accept_n) { + // Accept: target's argmax matches the candidate (simulate + // the chain runner's matching logic: drafts[i] == all_argmax[i]). + // all_argmax[i] = tokens[i+1] when i < g_actual (drafts). + // But for the bonus slot the runner reads all_argmax[accept_n]. + // We simply set all_argmax[i] = tokens[i] so drafts match. + (*all_argmax)[i] = tokens[i]; + } else { + (*all_argmax)[i] = argmax_token; + } + } + } + // Populate hidden seq so orchestrator prefill path succeeds. + hidden_seq_n = (int)tokens.size(); + hidden_seq_buf.assign((size_t)hidden_seq_n * hidden_sz, 0.1f); + return true; + } + + bool verify_tree(const std::vector &, const dflash::common::DDTree &, + int, std::vector &, std::vector *) override { return false; } + + int restore_kv_at_chain_calls = 0; + + bool snapshot_kv() override { return true; } + bool restore_kv() override { return true; } + bool restore_kv_at_dfs(const std::vector &) override { return false; } + bool restore_kv_at_chain(int) override { restore_kv_at_chain_calls++; return false; } // force slow path + void capture_topology_for_chain(int, int) override {} + void enable_chain_capture(bool) override {} + + bool is_eos(int tok) const override { return eos_token_id >= 0 && tok == eos_token_id; } + bool embed_tokens(const int32_t *, int, float *) const override { return false; } + bool project_hidden_to_tokens(const float *, int, + std::vector &) override { return false; } + bool project_hidden_to_logits(const float *, int, + std::vector &, int &) override { return false; } + int mask_token_id() const override { return 0; } + const std::vector & capture_layer_ids() const override { + static std::vector empty; + return empty; + } + ggml_backend * backend() const override { return nullptr; } + ggml_tensor * lm_head_weight() const override { return nullptr; } + int fa_window() const override { return 0; } + + const float * last_hidden_seq(int * out_n) const override { + if (out_n) *out_n = hidden_seq_n; + return hidden_seq_n > 0 ? hidden_seq_buf.data() : nullptr; + } + const float * last_hidden() const override { + return hidden_seq_buf.empty() ? nullptr : hidden_seq_buf.data(); + } +}; + +// MTP module that emits a fixed draft token from step_batch. +// Extends StubMtpModule; overrides step_batch to emit `draft_token`. +struct DraftStubMtpModule : public StubMtpModule { + int32_t draft_token = 99; // always propose this token + int warm_head_kv_calls = 0; + + bool step_batch(int32_t, int, + std::vector & out) override { + dflash::common::mtp::StepOutput so; + so.draft_token = draft_token; + out.push_back(so); + return true; + } + + bool warm_head_kv(const int32_t *, int, int32_t, const float *) override { + warm_head_kv_calls++; + return true; + } +}; + +// MTP module whose step_chain always returns false (simulate module failure). +struct FailStepChainMtpModule : public StubMtpModule { + bool step_chain(int32_t, int, int, + std::vector &) override { + return false; + } +}; + +// Backend for orchestrator lifecycle tests: prefill succeeds, exposes +// DraftStubMtpModule so orchestrator can complete its full control flow. +struct LiveStubBackend : public StubBackend { + DraftStubMtpModule mtp_mod; + SuccessStubTarget target; + LiveStubBackend() { supports_mtp_value = true; } + dflash::common::mtp::IMtpModule * mtp() override { return &mtp_mod; } + dflash::common::DFlashTarget * dflash_target() override { return ⌖ } +}; + +} // namespace + +// ─── T1: null backend pointer ─────────────────────────────────────────────── + +static void t1_null_backend() { + dflash::common::GenerateRequest req; + dflash::common::DaemonIO io; + auto res = dflash::common::mtp::warm_and_decode(nullptr, req, io); + assert(!res.ok); + assert(res.error.find("backend") != std::string::npos); + std::puts("T1 null_backend PASS"); +} + +// ─── T2: backend that does NOT support MTP — orchestrator declines cleanly ── + +static void t2_backend_without_mtp() { + StubBackend b; + b.supports_mtp_value = false; + dflash::common::GenerateRequest req; + dflash::common::DaemonIO io; + auto res = dflash::common::mtp::warm_and_decode(&b, req, io); + assert(!res.ok); + assert(res.error.find("mtp") != std::string::npos + || res.error.find("MTP") != std::string::npos); + std::puts("T2 backend_without_mtp PASS"); +} + +// ─── T3: empty prompt — orchestrator declines with an explicit error ──────── + +static void t3_empty_prompt() { + StubBackend b; + b.supports_mtp_value = true; + dflash::common::GenerateRequest req; + req.n_gen = 8; + dflash::common::DaemonIO io; + auto res = dflash::common::mtp::warm_and_decode(&b, req, io); + assert(!res.ok); + assert(res.error.find("prompt") != std::string::npos); + std::puts("T3 empty_prompt PASS"); +} + +// ─── T4: orchestrator drives a generic ModelBackend through abstract +// interfaces only (proves logic in common/ is not Qwen35-specific). +// Module needs to expose dflash_target() in the abstract — currently +// ModelBackend::dflash_target() lives on the base. ─────────────── + +static void t4_generic_backend_dispatch() { + FullStubBackend b; + dflash::common::GenerateRequest req; + req.prompt = {1, 2, 3, 4}; + req.n_gen = 4; + dflash::common::DaemonIO io; + auto res = dflash::common::mtp::warm_and_decode(&b, req, io); + assert(!res.ok); + // Reached verify_batch (which stub fails) — proves orchestrator depends + // only on ModelBackend / DFlashTarget / IMtpModule abstractions. + assert(res.error.find("verify_batch") != std::string::npos); + assert(b.target.verify_batch_calls >= 1); + std::puts("T4 generic_backend_dispatch PASS"); +} + +// ─── T5: gamma propagation — runner uses the gamma passed to run(); higher +// gamma (when module max allows) produces more proposals per iter. ── + +static void t5_gamma_propagation() { + // Module max_gamma=3 (from StubMtpModule). draft_token=99 always. + // Target: accept_n=0, so every iter accepts 0 drafts + 1 bonus. + // With gamma=1: proposed=1 per iter; with gamma=2: proposed=2 per iter. + // We run n_gen=1 (one iter) to keep it simple and compare proposed counts. + + DraftStubMtpModule mod1; + mod1.effective_gamma_value = 1; + SuccessStubTarget tgt1; + tgt1.accept_n = 0; + dflash::common::SamplerCfg sampler; + dflash::common::mtp::MtpChainRunner runner1(mod1, tgt1, sampler); + dflash::common::GenerateRequest req1; + req1.n_gen = 1; + req1.stream = false; + dflash::common::DaemonIO io; + auto res1 = runner1.run(req1, io, /*last_prefill_token=*/10, /*committed_pos=*/4, /*gamma=*/1); + assert(res1.ok); + const int proposed_g1 = runner1.stats().total_proposed; + + DraftStubMtpModule mod2; + mod2.effective_gamma_value = 2; + SuccessStubTarget tgt2; + tgt2.accept_n = 0; + dflash::common::mtp::MtpChainRunner runner2(mod2, tgt2, sampler); + dflash::common::GenerateRequest req2; + req2.n_gen = 2; + req2.stream = false; + auto res2 = runner2.run(req2, io, /*last_prefill_token=*/10, /*committed_pos=*/4, /*gamma=*/2); + assert(res2.ok); + const int proposed_g2 = runner2.stats().total_proposed; + + // gamma=2 run has >= 2 proposed (one full iter with g=2 -> proposed=2). + assert(proposed_g1 >= 1); + assert(proposed_g2 >= proposed_g1); + std::puts("T5 gamma_propagation PASS"); +} + +// ─── T6: EOS termination — runner stops when target returns is_eos=true. ── + +static void t6_eos_termination() { + DraftStubMtpModule mod; + mod.effective_gamma_value = 1; + SuccessStubTarget tgt; + // Bonus token will be argmax_token=42. Mark 42 as EOS. + tgt.argmax_token = 42; + tgt.eos_token_id = 42; + tgt.accept_n = 0; + + dflash::common::SamplerCfg sampler; + dflash::common::mtp::MtpChainRunner runner(mod, tgt, sampler); + dflash::common::GenerateRequest req; + req.n_gen = 100; // large; EOS should stop it early + req.stream = false; + dflash::common::DaemonIO io; + auto res = runner.run(req, io, /*last_prefill_token=*/10, /*committed_pos=*/4, /*gamma=*/1); + assert(res.ok); + const auto & st = runner.stats(); + assert(st.eos_hits >= 1); + // EOS is the first token emitted (bonus=42=eos): runner stops immediately. + // total_emitted must be exactly 1, not the n_gen=100 cap or any loose bound. + assert(st.total_emitted == 1); + std::puts("T6 eos_termination PASS"); +} + +// ─── T7: partial-accept rollback — when target accepts K < gamma drafts, +// runner advances by K and stats.total_accepted grows by K per iter. ─ + +static void t7_partial_accept_rollback() { + // gamma=2, accept_n=1 in the target. + // Each iter: proposed 2, accepted 1, emitted 2 (1 accepted + 1 bonus). + DraftStubMtpModule mod; + mod.effective_gamma_value = 2; + SuccessStubTarget tgt; + tgt.accept_n = 1; + tgt.argmax_token = 55; // bonus token + + dflash::common::SamplerCfg sampler; + dflash::common::mtp::MtpChainRunner runner(mod, tgt, sampler); + dflash::common::GenerateRequest req; + req.n_gen = 2; + req.stream = false; + dflash::common::DaemonIO io; + auto res = runner.run(req, io, /*last_prefill_token=*/10, /*committed_pos=*/4, /*gamma=*/2); + assert(res.ok); + const auto & st = runner.stats(); + // Exactly 2 tokens emitted (1 accepted + 1 bonus = 2, capped by n_gen=2). + assert(st.total_accepted >= 1); + assert(st.total_emitted >= 1); + // Restore path was hit (accept_n < g_actual -> restore_kv_at_chain called). + assert(tgt.verify_calls >= 1); + assert(tgt.restore_kv_at_chain_calls >= 1); + std::puts("T7 partial_accept_rollback PASS"); +} + +// ─── T8: n_gen termination — runner emits exactly n_gen tokens when no EOS. ─ + +static void t8_n_gen_termination() { + DraftStubMtpModule mod; + mod.effective_gamma_value = 1; + SuccessStubTarget tgt; + tgt.argmax_token = 77; + tgt.eos_token_id = -1; // no EOS + tgt.accept_n = 0; + + dflash::common::SamplerCfg sampler; + dflash::common::mtp::MtpChainRunner runner(mod, tgt, sampler); + dflash::common::GenerateRequest req; + req.n_gen = 5; + req.stream = false; + dflash::common::DaemonIO io; + auto res = runner.run(req, io, /*last_prefill_token=*/10, /*committed_pos=*/4, /*gamma=*/1); + assert(res.ok); + assert((int)res.tokens.size() == 5); + const auto & st = runner.stats(); + assert(st.total_emitted == 5); + assert(st.eos_hits == 0); + std::puts("T8 n_gen_termination PASS"); +} + +// ─── T9: propose failure — when step_chain returns false, runner aborts and +// returns ok=false with error "mtp.propose". ──────────────────────── + +static void t9_propose_failure() { + FailStepChainMtpModule mod; + mod.effective_gamma_value = 1; + SuccessStubTarget tgt; + + dflash::common::SamplerCfg sampler; + dflash::common::mtp::MtpChainRunner runner(mod, tgt, sampler); + dflash::common::GenerateRequest req; + req.n_gen = 4; + req.stream = false; + dflash::common::DaemonIO io; + auto res = runner.run(req, io, /*last_prefill_token=*/10, /*committed_pos=*/4, /*gamma=*/1); + assert(!res.ok); + assert(res.error.find("propose") != std::string::npos + || res.error.find("mtp") != std::string::npos); + std::puts("T9 propose_failure PASS"); +} + +// ─── T10: stats accounting — total_emitted == total_accepted + total_iters ─ +// (each iter adds exactly 1 bonus token to emitted). + +static void t10_stats_accounting() { + DraftStubMtpModule mod; + mod.effective_gamma_value = 2; + SuccessStubTarget tgt; + tgt.accept_n = 1; // accept 1 draft + 1 bonus = 2 emitted per iter + tgt.argmax_token = 33; + tgt.eos_token_id = -1; + + dflash::common::SamplerCfg sampler; + dflash::common::mtp::MtpChainRunner runner(mod, tgt, sampler); + dflash::common::GenerateRequest req; + req.n_gen = 6; + req.stream = false; + dflash::common::DaemonIO io; + auto res = runner.run(req, io, /*last_prefill_token=*/10, /*committed_pos=*/4, /*gamma=*/2); + assert(res.ok); + const auto & st = runner.stats(); + // Invariant: each iter emits accept_n + 1 (one bonus), so: + // total_emitted == total_accepted + total_iters + assert(st.total_emitted == st.total_accepted + st.total_iters); + std::puts("T10 stats_accounting PASS"); +} + +// ─── T11: reset_chain called before chain runner drive ──────────────────── + +static void t11_reset_chain_before_drive() { + LiveStubBackend b; + b.mtp_mod.effective_gamma_value = 1; + dflash::common::GenerateRequest req; + req.prompt = {1, 2, 3}; + req.n_gen = 2; + dflash::common::DaemonIO io; + auto res = dflash::common::mtp::warm_and_decode(&b, req, io); + // reset_chain() is called once by the orchestrator before drive. + assert(b.mtp_mod.reset_chain_calls >= 1); + // Result should succeed (prefill passes, chain runs). + assert(res.ok); + std::puts("T11 reset_chain_before_drive PASS"); +} + +// ─── T12: set_initial_hidden plumbing ────────────────────────────────────── +// Orchestrator reads target->last_hidden() and forwards via +// module->set_initial_hidden(). Asserts the call count == 1. + +static void t12_set_initial_hidden_plumbing() { + LiveStubBackend b; + b.mtp_mod.effective_gamma_value = 1; + dflash::common::GenerateRequest req; + req.prompt = {5, 6, 7, 8}; + req.n_gen = 1; + dflash::common::DaemonIO io; + auto res = dflash::common::mtp::warm_and_decode(&b, req, io); + assert(res.ok); + // The orchestrator calls set_initial_hidden once (if last_hidden() != null). + // SuccessStubTarget::last_hidden() returns non-null after verify_batch, + // so the orchestrator must have called set_initial_hidden. + assert(b.mtp_mod.set_initial_hidden_calls >= 1); + std::puts("T12 set_initial_hidden_plumbing PASS"); +} + +// ─── T13: gamma derived from module::effective_gamma() ─────────────────── +// Orchestrator reads module->effective_gamma() and passes it to +// MtpChainRunner::run(). We set gamma=2, expect proposed >= 2. + +static void t13_gamma_derived_from_module() { + LiveStubBackend b; + b.mtp_mod.effective_gamma_value = 2; + b.mtp_mod.draft_token = 88; + b.target.accept_n = 0; + b.target.argmax_token = 55; + dflash::common::GenerateRequest req; + req.prompt = {1, 2}; + req.n_gen = 3; + dflash::common::DaemonIO io; + auto res = dflash::common::mtp::warm_and_decode(&b, req, io); + assert(res.ok); + // At least some tokens generated; can't directly inspect the runner's + // stats from here, but success proves the orchestrator read effective_gamma + // (gamma==0 would have returned error "effective_gamma() == 0"). + assert(res.tokens.size() >= 1); + std::puts("T13 gamma_derived_from_module PASS"); +} + +// ─── T14: zero effective_gamma rejected by orchestrator ────────────────── +// If module->effective_gamma() == 0, orchestrator must return +// an error rather than passing gamma=0 to the chain runner. + +static void t14_zero_gamma_rejected() { + LiveStubBackend b; + b.mtp_mod.effective_gamma_value = 0; // backend forgot to set gamma + dflash::common::GenerateRequest req; + req.prompt = {1, 2, 3}; + req.n_gen = 4; + dflash::common::DaemonIO io; + auto res = dflash::common::mtp::warm_and_decode(&b, req, io); + assert(!res.ok); + assert(res.error.find("effective_gamma") != std::string::npos + || res.error.find("gamma") != std::string::npos); + std::puts("T14 zero_gamma_rejected PASS"); +} + +// ─── T15: Qwen35MtpModule::attach(nullptr) returns false without crash ──── + +static void t15_attach_null_returns_false() { + dflash::common::mtp::Qwen35MtpModule mod; + // No init() — module is not loaded. attach(nullptr) must return false. + bool ok = mod.attach(nullptr); + assert(!ok); + std::puts("T15 attach_null_returns_false PASS"); +} + +// ─── T16: set_effective_gamma clamps to max_gamma() ────────────────────── +// Pre-init: max_gamma() == 0, so any positive gamma is clamped. +// Post attach_weights_for_test: max_gamma() == 8 (production ceiling). + +static void t16_set_effective_gamma_clamping() { + dflash::common::mtp::Qwen35MtpModule mod; + + // Pre-init: max_gamma()==0, so effective_gamma stays 0 after any set call. + mod.set_effective_gamma(5); + // Implementation: (gamma > 0) ? std::min(gamma, max_gamma()) : max_gamma() + // With max_gamma()==0: std::min(5, 0) == 0. + assert(mod.effective_gamma() == 0); + + // Inject minimal weights so loaded==true; max_gamma() returns 8. + dflash::common::mtp::Qwen35MtpWeights w; + w.n_embd = 4; + w.n_vocab = 16; + w.n_heads = 1; + w.n_backbone_layers = 1; + w.n_head_count = 1; + w.n_head_kv = 1; + w.n_key_length = 4; + w.n_value_length = 4; + w.n_ffn_length = 8; + w.heads.resize(1); + mod.attach_weights_for_test(w); + // max_gamma() should now be 8. + assert(mod.max_gamma() == 8); + + // Value within range: set 3 -> stays 3. + mod.set_effective_gamma(3); + assert(mod.effective_gamma() == 3); + + // Value above max: set 99 -> clamped to 8. + mod.set_effective_gamma(99); + assert(mod.effective_gamma() == mod.max_gamma()); + std::puts("T16 set_effective_gamma_clamping PASS"); +} + +// ─── T17: step_batch returns false when not attached ───────────────────── + +static void t17_step_batch_not_attached() { + dflash::common::mtp::Qwen35MtpModule mod; + // No init/attach — state.loaded==false. + std::vector out; + bool ok = mod.step_batch(0, 0, out); + assert(!ok); + assert(out.empty()); + std::puts("T17 step_batch_not_attached PASS"); +} + +// ─── T18: shutdown() is idempotent ──────────────────────────────────────── + +static void t18_shutdown_idempotent() { + dflash::common::mtp::Qwen35MtpModule mod; + // Two shutdown calls without init; should not crash. + mod.shutdown(); + mod.shutdown(); + // After double shutdown: max_gamma()==0 (not loaded). + assert(mod.max_gamma() == 0); + std::puts("T18 shutdown_idempotent PASS"); +} + +// ─── T19: reset_chain() before attach() is a safe no-op ────────────────── + +static void t19_reset_chain_before_attach() { + dflash::common::mtp::Qwen35MtpModule mod; + // reset_chain() checks state_->loaded; before init it should be safe. + mod.reset_chain(); + mod.reset_chain(); + // No crash, max_gamma still 0. + assert(mod.max_gamma() == 0); + std::puts("T19 reset_chain_before_attach PASS"); +} + +// ─── T20 support types ──────────────────────────────────────────────────── + +// ExternalDrafter stub: proposes two tokens (100, 101) per step chain and +// fills next_hidden so the fix branch (set_capture_row / consume_captured_hidden) +// is reachable. Records every call to the two methods under test. +struct StubExternalDrafter : public dflash::common::mtp::IExternalDrafterMtp { + // IMtpModule tuneables + int max_gamma_value = 2; + int effective_gamma_value = 2; + int hidden_size_value = 8; + + // Call counters — IMtpModule lifecycle + int reset_chain_calls = 0; + int set_initial_hidden_calls = 0; + int attach_calls = 0; + + // Call counters — IExternalDrafterMtp + int step_calls = 0; + int enable_capture_calls = 0; + int set_capture_row_calls = 0; + int set_capture_row_last_arg = -999; // sentinel: "never called" + int consume_calls = 0; + bool consume_after_set_capture = false; // ordering assertion + + std::vector donor_layers_value{0, 1}; + + // IMtpModule + int max_gamma() const override { return max_gamma_value; } + int effective_gamma() const override { return effective_gamma_value; } + void set_effective_gamma(int g) override { effective_gamma_value = g; } + int hidden_size() const override { return hidden_size_value; } + bool attach(dflash::common::DFlashTarget *) override { ++attach_calls; return true; } + void reset_chain() override { ++reset_chain_calls; } + void shutdown() override {} + void set_initial_hidden(const float *, int) override { ++set_initial_hidden_calls; } + + // IExternalDrafterMtp + bool step(const dflash::common::mtp::StepInput & in, + dflash::common::mtp::StepOutput & out) override { + ++step_calls; + // Propose token 100+gamma_index so each step produces a distinct token. + out.draft_token = 100 + in.gamma_index; + // Non-empty next_hidden activates the fix branch in the runner. + out.next_hidden.assign(hidden_size_value, 0.5f); + return true; + } + + const std::vector & donor_layers() const override { + return donor_layers_value; + } + + bool enable_target_hidden_capture(bool /*batch_mode*/, + int /*gamma_max*/) override { + ++enable_capture_calls; + return true; + } + + void set_capture_row(int row) override { + ++set_capture_row_calls; + set_capture_row_last_arg = row; + } + + bool consume_captured_hidden(float * out, int dim) override { + ++consume_calls; + // Order check: set_capture_row must have been called first. + consume_after_set_capture = (set_capture_row_calls > 0); + for (int i = 0; i < dim; ++i) out[i] = 1.0f; // sentinel value + return true; + } +}; + +// Target that forces the runner to reach accept_n=1 out of gamma=2. +// verify_batch is overridden so that: +// all_argmax[0] == drafts[0] (accept first token) +// all_argmax[1] != drafts[1] (reject second token — triggers partial-accept) +// The drafts are always 100 (gamma_index=0) and 101 (gamma_index=1), so we +// set all_argmax[0]=100 (matches) and all_argmax[1]=999 (diverges). +struct ExternalPartialTarget : public SuccessStubTarget { + static constexpr int32_t kDivergeToken = 999; + + bool verify_batch(const std::vector & tokens, + int base_pos, + int & last_tok, + std::vector * all_argmax) override { + // Let SuccessStubTarget handle bookkeeping (verify_calls, hidden_seq). + SuccessStubTarget::verify_batch(tokens, base_pos, last_tok, all_argmax); + // Override all_argmax to produce the desired partial-accept pattern. + // candidate = [cur_tok, 100, 101]; drafts = [100, 101]. + // all_argmax[0]=100 -> drafts[0]==100 matches (accept). + // all_argmax[1]=999 -> drafts[1]==101 != 999 (reject). + if (all_argmax && (int)all_argmax->size() >= 2) { + (*all_argmax)[0] = 100; // accept draft[0] + (*all_argmax)[1] = kDivergeToken; // reject draft[1] + } + last_tok = kDivergeToken; + return true; + } +}; + +// ─── T20: ExternalDrafter partial-accept threads committed-boundary hidden ─ +// +// Proves the 74f708a fix: on partial accept (accept_n < g_actual) the runner +// calls set_capture_row(accept_n) then consume_captured_hidden(), in that +// order, to thread the target-captured boundary hidden into running_hidden. +// T7 used a NativeHeads stub (next_hidden always empty) so it never exercised +// this branch. This test closes that gap. + +static void t20_external_drafter_partial_accept_threads_committed_row() { + StubExternalDrafter ext; + ExternalPartialTarget target; + // accept_n field on SuccessStubTarget is not used directly here (we + // override verify_batch above), but keep it consistent. + target.argmax_token = ExternalPartialTarget::kDivergeToken; + target.hidden_sz = ext.hidden_size_value; + + dflash::common::SamplerCfg sampler; + dflash::common::mtp::MtpChainRunner runner(ext, target, sampler); + + dflash::common::GenerateRequest req; + req.n_gen = 2; // enough for one full chain iteration + req.stream = false; + + dflash::common::DaemonIO io; + auto res = runner.run(req, io, /*last_prefill_token=*/10, /*committed_pos=*/4, + /*gamma=*/2); + + assert(res.ok && "runner returned error"); + + // The runner must have called set_capture_row with arg == accept_n (1). + assert(ext.set_capture_row_calls > 0 && "set_capture_row was not called"); + assert(ext.set_capture_row_last_arg == 1 && "set_capture_row arg != accept_n=1"); + + // consume_captured_hidden must have been called after set_capture_row. + assert(ext.consume_calls > 0 && "consume_captured_hidden was not called"); + assert(ext.consume_after_set_capture && "consume called BEFORE set_capture_row"); + + std::puts("T20 external_drafter_partial_accept_threads_committed_row PASS"); +} + +int main() { + t1_null_backend(); + t2_backend_without_mtp(); + t3_empty_prompt(); + t4_generic_backend_dispatch(); + // Area A: MtpChainRunner state machine + t5_gamma_propagation(); + t6_eos_termination(); + t7_partial_accept_rollback(); + t8_n_gen_termination(); + t9_propose_failure(); + t10_stats_accounting(); + // Area B: MtpOrchestrator lifecycle + t11_reset_chain_before_drive(); + t12_set_initial_hidden_plumbing(); + t13_gamma_derived_from_module(); + t14_zero_gamma_rejected(); + // Area C: Qwen35MtpModule error paths + t15_attach_null_returns_false(); + t16_set_effective_gamma_clamping(); + t17_step_batch_not_attached(); + t18_shutdown_idempotent(); + t19_reset_chain_before_attach(); + // Area D: ExternalDrafter partial-accept hidden threading (covers 74f708a fix) + t20_external_drafter_partial_accept_threads_committed_row(); + std::puts("ALL PASS"); + return 0; +} diff --git a/dflash/test/test_dflash.cpp b/dflash/test/test_dflash.cpp index 32ed24e2..740a9282 100644 --- a/dflash/test/test_dflash.cpp +++ b/dflash/test/test_dflash.cpp @@ -27,10 +27,13 @@ // dflash::common::run_laguna_daemon() instead of the // qwen35 + DFlash + DDTree pipeline below. #include "qwen35_daemon.h" // arch dispatch - single-GPU qwen35 daemon mode +#include "qwen35_backend.h" // Qwen3.6 MTP bench path reuses qwen35 backbone #include "qwen35_layer_split.h" // multi-GPU layer-split daemon args #include "layer_split_daemon_loop.h" // extracted layer-split daemon loop #include "qwen3_daemon.h" // arch dispatch - qwen3 (0.6B standalone) #include "gemma4_daemon.h" // arch dispatch - gemma4 (iSWA + MoE) +#include "qwen35/qwen35_mtp.h" +#include "common/mtp_chain_runner.h" #include "sampler.h" // shared CPU sampler chain (SamplerCfg / // sample_logits / parse_sampler_token) used by // both arches; behaviour stays identical. @@ -647,6 +650,485 @@ static int run_target_layer_split_harness( return 0; } +// use_topk: false = MtpChainRunner chain path (default), true = mtp_topk DDTree. +static int run_qwen35_mtp_harness(const char * target_path, + const char * mtp_gguf_path, + const char * prompt_path, + int n_gen, + const char * out_path, + int gamma, + int prompt_id, + int target_gpu, + int max_ctx, + bool use_topk, + int draft_topk, + int ddtree_budget, + bool ddtree_chain_seed, + float ddtree_temp) { + if (!target_path || !mtp_gguf_path || !prompt_path) { + std::fprintf(stderr, "qwen35-mtp requires target, --mtp-gguf, and --prompt-bin\n"); + return 2; + } + if (n_gen <= 0) { + std::fprintf(stderr, "qwen35-mtp requires --n-gen > 0\n"); + return 2; + } + if (gamma < 0) { + std::fprintf(stderr, "qwen35-mtp requires --gamma >= 0\n"); + return 2; + } + + std::vector prompt = read_int32_file(prompt_path); + if (prompt.empty()) { + std::fprintf(stderr, "qwen35-mtp empty prompt: %s\n", prompt_path); + return 1; + } + if ((int)prompt.size() + n_gen + std::max(1, gamma) + 1 > max_ctx) { + std::fprintf(stderr, + "qwen35-mtp prompt (%zu) + n_gen (%d) exceeds max_ctx (%d)\n", + prompt.size(), n_gen, max_ctx); + return 1; + } + + Qwen35Config cfg; + cfg.target_path = target_path; + cfg.draft_path = nullptr; + cfg.device.gpu = target_gpu; + cfg.device.max_ctx = max_ctx; + cfg.draft_gpu = target_gpu; + cfg.fa_window = g_fa_window; + cfg.kq_stride_pad = g_kq_stride_pad; + cfg.draft_ctx_max = 0; + + Qwen35Backend backend(cfg); + if (!backend.init()) { + std::fprintf(stderr, "qwen35-mtp backend init failed\n"); + return 1; + } + if (!backend.ensure_decode_cache(std::max(DFLASH27B_DRAFT_BLOCK_SIZE, gamma + 1))) { + std::fprintf(stderr, "qwen35-mtp decode cache: %s\n", dflash27b_last_error()); + return 1; + } + + DFlashTarget * target = backend.dflash_target(); + if (!target) { + std::fprintf(stderr, "qwen35-mtp target adapter unavailable\n"); + return 1; + } + + std::unique_ptr mtp_module; + if (gamma > 0) { + mtp_module = std::make_unique(); + std::string err; + if (!mtp_module->init(mtp_gguf_path, target, err)) { + std::fprintf(stderr, "qwen35-mtp init failed: %s\n", err.c_str()); + return 1; + } + if (!mtp_module->attach(target)) { + std::fprintf(stderr, "qwen35-mtp attach(target) failed\n"); + return 1; + } + // Shape B (PR 2e-final): the MTP module reads the backbone's final + // post-norm hidden via DFlashTarget::last_hidden() which is populated + // by Qwen35DFlashTarget after every verify_batch call. No backbone + // block attachment needed; set_initial_hidden() is called once after + // prefill (below) and the module auto-pulls target->last_hidden() for + // subsequent chain runner iterations. + } + + auto t_prefill0 = std::chrono::steady_clock::now(); + int32_t prefill_next = -1; + const int prefill_ubatch = 512; + // Accumulate the backbone's post-norm hidden for every prefill position so + // we can warm the MTP head's KV cache after prefill completes. + std::vector all_prefill_hidden; + if (mtp_module) { + all_prefill_hidden.resize((size_t)prompt.size() * target->hidden_size()); + } + for (int start = 0; start < (int)prompt.size(); start += prefill_ubatch) { + const int n = std::min(prefill_ubatch, (int)prompt.size() - start); + std::vector chunk(prompt.begin() + start, + prompt.begin() + start + n); + if (!target->verify_batch(chunk, start, prefill_next, nullptr)) { + std::fprintf(stderr, "qwen35-mtp prefill failed at %d\n", start); + return 1; + } + if (mtp_module) { + int n_chunk = 0; + const float * h_seq = target->last_hidden_seq(&n_chunk); + if (h_seq && n_chunk == n) { + std::memcpy(all_prefill_hidden.data() + + (size_t)start * target->hidden_size(), + h_seq, + sizeof(float) * (size_t)n * target->hidden_size()); + } else { + std::fprintf(stderr, + "qwen35-mtp prefill chunk hidden seq missing: " + "expected %d tokens, got %d — aborting warm\n", n, n_chunk); + all_prefill_hidden.clear(); // prevent warm_head_kv on bad buffer + break; + } + } + } + + // Seed the MTP module with the backbone's final hidden after prefill, and + // warm the head's KV cache over all prefill positions. + if (mtp_module && target->last_hidden()) { + mtp_module->set_initial_hidden(target->last_hidden(), target->hidden_size()); + } + if (mtp_module && !all_prefill_hidden.empty() && prefill_next >= 0) { + if (!mtp_module->warm_head_kv(prompt.data(), (int)prompt.size(), + prefill_next, all_prefill_hidden.data())) { + std::fprintf(stderr, "qwen35-mtp warm_head_kv failed\n"); + return 1; + } + } + auto t_prefill1 = std::chrono::steady_clock::now(); + const double prefill_s = std::chrono::duration(t_prefill1 - t_prefill0).count(); + if (prefill_next < 0) { + std::fprintf(stderr, "qwen35-mtp prefill produced invalid token\n"); + return 1; + } + + DaemonIO io; + io.stream_fd = -1; + std::vector generated; + generated.reserve(n_gen); + + auto t_decode0 = std::chrono::steady_clock::now(); + generated.push_back(prefill_next); + + int accepted = 0; + int proposed = 0; + if (!target->is_eos(prefill_next) && n_gen > 1) { + if (gamma == 0) { + int32_t cur = prefill_next; + int base_pos = (int)prompt.size(); + while ((int)generated.size() < n_gen) { + int32_t next = -1; + std::vector one{cur}; + if (!target->verify_batch(one, base_pos, next, nullptr)) { + std::fprintf(stderr, "qwen35-mtp AR decode failed at pos %d\n", base_pos); + return 1; + } + generated.push_back(next); + cur = next; + base_pos++; + if (target->is_eos(next)) break; + } + } else if (!use_topk) { + mtp_module->reset_chain(); + GenerateRequest req; + req.n_gen = n_gen - 1; + req.stream = false; + req.do_sample = false; + mtp::MtpChainRunner runner(*mtp_module, *target, SamplerCfg{}); + GenerateResult res = runner.run(req, io, + prefill_next, + (int)prompt.size(), + gamma); + if (!res.ok) { + std::fprintf(stderr, "qwen35-mtp runner failed: %s\n", res.error.c_str()); + return 1; + } + generated.insert(generated.end(), res.tokens.begin(), res.tokens.end()); + accepted = runner.stats().total_accepted; + proposed = runner.stats().total_proposed; + } else if (use_topk) { + // ── experiment C: MTP top-K → DDTree → chain-verify ───────── + // 1. step_batch with K>1 populates StepOutput.topk_logprobs/ids + // on every emitted head (length K, sorted DESCENDING). + // 2. Stack the per-head topk into [L × K] arrays for build_ddtree. + // 3. Build DDTree with the configured budget + chain_seed. + // 4. Verify the DDTree's top-1 chain against the target via + // verify_batch (chain verify with all_argmax). Sequential + // accept on first argmax-mismatch — same semantics as the + // existing MtpChainRunner, but the draft chain is sourced + // from the DDTree root-to-leaf top-1 path, NOT from + // StepOutput.draft_token directly. With K=1 this collapses + // to the chain path. With K>1 the DDTree may pick a + // different chain (e.g. via chain_seed=false best-first). + // + // BLOCKER: a true tree-mask verify against the target would + // need DFlashTarget to grow a tree-verify entry point (or to + // share test_dflash's spec-decode loop). Today we surface the + // DDTree build + report mean_tree_size to confirm the + // composition is invokable; the verify path is still chain. + mtp_module->set_draft_topk(std::max(1, draft_topk)); + mtp_module->reset_chain(); + const int L_max = mtp_module->num_heads(); + int32_t cur = prefill_next; + int base_pos = (int)prompt.size(); + const int K = std::max(1, draft_topk); + std::vector ddtree_logp; // [L_max × K], reused + std::vector ddtree_ids; // [L_max × K], reused + ddtree_logp.assign((size_t)L_max * K, 0.0f); + ddtree_ids.assign((size_t)L_max * K, 0); + long long sum_tree_size = 0; + int n_steps = 0; + while ((int)generated.size() < n_gen) { + std::vector outs; + if (!mtp_module->step_batch(cur, base_pos, outs)) { + std::fprintf(stderr, "qwen35-mtp[topk] step_batch failed at pos %d\n", base_pos); + return 1; + } + const int L = std::min((int)outs.size(), L_max); + if (L <= 0) break; + // Stack per-head top-K into [L × K]. With K=1 we synthesize + // a degenerate distribution from draft_logit so build_ddtree + // still emits a chain. + if (K == 1) { + for (int i = 0; i < L; i++) { + ddtree_logp[(size_t)i * K + 0] = 0.0f; + ddtree_ids [(size_t)i * K + 0] = outs[i].draft_token; + } + } else { + for (int i = 0; i < L; i++) { + if ((int)outs[i].topk_logprobs.size() != K || + (int)outs[i].topk_ids.size() != K) { + std::fprintf(stderr, + "qwen35-mtp[topk] head %d: expected K=%d topk entries, " + "got logp=%zu ids=%zu\n", + i, K, outs[i].topk_logprobs.size(), outs[i].topk_ids.size()); + return 1; + } + std::memcpy(ddtree_logp.data() + (size_t)i * K, + outs[i].topk_logprobs.data(), + sizeof(float) * K); + std::memcpy(ddtree_ids.data() + (size_t)i * K, + outs[i].topk_ids.data(), + sizeof(int32_t) * K); + } + } + DDTree tree = build_ddtree( + ddtree_logp.data(), ddtree_ids.data(), + L, K, + std::max(1, ddtree_budget), + ddtree_chain_seed); + // N = 1 + tree.n_nodes — count root + DFS-ordered nodes. + // Per Stage 2 brief: report mean_tree_size = N (the actual + // graph_compute batch size for tree-verify), not the + // accepted-path depth. + sum_tree_size += 1 + tree.n_nodes; + n_steps++; + (void)ddtree_temp; // temperature is consumed by extract_draft_topk + // when called from the external-drafter path; + // MTP path emits log-softmax directly. + + // Stage 1: try DFlashTarget::verify_tree first. If the + // target has a real tree-verify implementation (Stage 2+), + // this path commits accepted tree nodes + the next bonus + // token in one graph_compute. If it returns false (stub + // for n_nodes > 0 today), fall through to chain-verify of + // the DDTree's top-1 spine. + { + std::vector flat; + flat.reserve(1 + tree.n_nodes); + flat.push_back(cur); + for (int i = 0; i < tree.n_nodes; i++) flat.push_back(tree.token_ids[i]); + + std::vector tree_argmax; + if (target->verify_tree(flat, tree, base_pos, tree_argmax, /*out_logits=*/nullptr)) { + int next_token = -1; + int bonus_node_idx = 0; + std::vector accepted_path = follow_verified_tree( + tree, tree_argmax.data(), next_token, &bonus_node_idx); + const int accept_depth = (int)accepted_path.size(); // includes root + const int draft_depth = std::max(0, accept_depth - 1); + // Track how many DFS slots were actually committed to + // KV for restore_kv_at_dfs. We always commit root + // (= last bonus, slot 0), and each accepted child up + // to the n_gen cap. + int committed_dfs_n = 1; // root always committed + bool tt_eos_or_cap = false; + for (int i = 1; i < accept_depth; i++) { + const int node_idx = accepted_path[i]; // 1..n_nodes + const int32_t tok = tree.token_ids[node_idx - 1]; + generated.push_back(tok); + committed_dfs_n++; + if (target->is_eos(tok) || (int)generated.size() >= n_gen) { + cur = tok; tt_eos_or_cap = true; break; + } + } + if (!tt_eos_or_cap && next_token >= 0) { + generated.push_back((int32_t)next_token); + cur = (int32_t)next_token; + base_pos += draft_depth + 1; + } else if (!tt_eos_or_cap) { + // No bonus available (degenerate tree); advance + // only over the accepted draft nodes. + base_pos += draft_depth; + } + proposed += tree.n_nodes; + accepted += draft_depth; + if (tt_eos_or_cap || target->is_eos(cur) || (int)generated.size() >= n_gen) { + goto topk_done; + } + // Stage 3 (oracle blocker 5.3): roll back DeltaNet + // SSM/conv + full-attn KV to the deepest committed + // DFS slot so the next iter's verify sees the + // accepted-path tail (not the poisoned tail that + // included rejected siblings + bonus DFS slots). + // Without this, multi-iter tree-verify produces + // wrong output the moment any sibling subtree was + // forwarded but not accepted. + std::vector commit_prefix( + accepted_path.begin(), + accepted_path.begin() + committed_dfs_n); + if (!target->restore_kv_at_dfs(commit_prefix)) { + std::fprintf(stderr, + "qwen35-mtp[topk] restore_kv_at_dfs failed " + "(commit_n=%d, deepest_dfs=%d)\n", + committed_dfs_n, + committed_dfs_n > 0 + ? commit_prefix[committed_dfs_n - 1] : -1); + return 1; + } + continue; // next outer iter — skip chain-verify fallback + } + } + + // Build the DDTree's top-1 chain (root → deepest top-1 child). + // Slot 0 is the root (= last accepted token); we follow each + // node's first child (which build_ddtree places first in DFS + // order via chain_seed) until no children remain. + std::vector chain; + chain.reserve(L + 1); + { + int node = 0; // root + while ((int)chain.size() < L) { + // Find the first DFS child of `node`: the smallest + // index i in [1, n_nodes] whose parents[i] == node. + int first_child = -1; + for (int i = 1; i <= tree.n_nodes; i++) { + if (tree.parents[i] == node) { first_child = i; break; } + } + if (first_child < 0) break; + chain.push_back(tree.token_ids[first_child - 1]); + node = first_child; + } + } + if (chain.empty()) { + // Degenerate: empty tree. Fall back to argmax of head 0. + chain.push_back(outs[0].draft_token); + } + + // Chain-verify: send [cur, chain[0..g-1]] through verify_batch + // and accept on first argmax-mismatch (same semantics as + // MtpChainRunner.run()). + const int g = (int)chain.size(); + std::vector candidate; + candidate.reserve(g + 1); + candidate.push_back(cur); + for (int i = 0; i < g; i++) candidate.push_back(chain[i]); + std::vector all_argmax; + int32_t last_argmax = -1; + if (!target->verify_batch(candidate, base_pos, last_argmax, &all_argmax)) { + std::fprintf(stderr, "qwen35-mtp[topk] verify_batch failed at pos %d\n", base_pos); + return 1; + } + if ((int)all_argmax.size() < g + 1) { + std::fprintf(stderr, "qwen35-mtp[topk] verify_batch short: got %zu expected %d\n", + all_argmax.size(), g + 1); + return 1; + } + // all_argmax[i] is the target's argmax AT position base_pos+i, + // conditioned on tokens[0..i]. The "next correct" token for + // candidate[i] is candidate[i+1]; we accept while they match. + int accept_k = 0; + for (int i = 0; i < g; i++) { + if (chain[i] == all_argmax[i]) accept_k++; + else break; + } + proposed += g; + accepted += accept_k; + // Commit accept_k draft tokens + 1 bonus (the target's argmax + // at the first mismatch position, which is all_argmax[accept_k]). + for (int i = 0; i < accept_k; i++) { + generated.push_back(chain[i]); + if (target->is_eos(chain[i])) { cur = chain[i]; goto topk_done; } + if ((int)generated.size() >= n_gen) { cur = chain[i]; goto topk_done; } + } + { + int32_t bonus = all_argmax[accept_k]; + generated.push_back(bonus); + cur = bonus; + base_pos += accept_k + 1; + if (target->is_eos(bonus) || (int)generated.size() >= n_gen) break; + } + // NB: verify_batch wrote g+1 KV slots but we only want + // accept_k+1 committed. The existing MtpChainRunner solves + // this with snapshot_kv/restore_kv + recommit. For the + // experiment-C wiring (chain verify of top-1 path), the + // simple route is to restore and recommit on partial accept; + // we keep the bookkeeping simple and accept the same KV + // overhead the chain runner has — the bench is comparing + // tok/s, accept-rate, and tree-size, not raw KV efficiency. + if (accept_k < g) { + // KV currently holds candidate[0..g] starting at base_pos. + // We want only [accept_k+1] tokens committed. Restore the + // pre-verify snapshot is not available here (chain runner + // takes the snapshot/restore path); skip — the next + // verify_batch will overwrite the same KV slots, and the + // bonus position is re-processed. + base_pos += 0; // tracked above on the accept_k==g branch + } + } + topk_done:; + const double mean_tree_size = n_steps > 0 + ? (double)sum_tree_size / (double)n_steps : 0.0; + const double mean_gamma = proposed > 0 && n_steps > 0 + ? (double)proposed / (double)n_steps : 0.0; + std::fprintf(stderr, + "[qwen35-mtp topk] K=%d budget=%d chain_seed=%d steps=%d " + "mean_tree_size=%.2f mean_gamma=%.2f\n", + K, ddtree_budget, (int)ddtree_chain_seed, + n_steps, mean_tree_size, mean_gamma); + } + } + auto t_decode1 = std::chrono::steady_clock::now(); + const double decode_s = std::chrono::duration(t_decode1 - t_decode0).count(); + const double tok_s = decode_s > 0.0 ? (double)generated.size() / decode_s : 0.0; + + if (out_path && *out_path) { + std::vector all = prompt; + all.insert(all.end(), generated.begin(), generated.end()); + write_int32_file(out_path, all); + } + + std::printf("RESULT tok_s=%.2f prompt=%d gamma=%d tokens=%zu decode_s=%.6f prefill_s=%.6f accepted=%d proposed=%d\n", + tok_s, prompt_id, gamma, generated.size(), decode_s, prefill_s, + accepted, proposed); + // Single JSON line for downstream bench scripts (experiment C wiring). + // Always emitted so chain vs mtp_topk runs are comparable record-for-record. + { + const double accept_rate = proposed > 0 + ? (double)accepted / (double)proposed : 0.0; + const char * src = use_topk ? "mtp_topk" : "chain"; + std::printf("RESULT_JSON {" + "\"draft_source\":\"%s\"," + "\"gamma\":%d," + "\"draft_topk\":%d," + "\"ddtree_budget\":%d," + "\"ddtree_chain_seed\":%s," + "\"prompt_id\":%d," + "\"tokens\":%zu," + "\"decode_s\":%.6f," + "\"prefill_s\":%.6f," + "\"tok_s\":%.4f," + "\"accepted\":%d," + "\"proposed\":%d," + "\"accept_rate\":%.4f" + "}\n", + src, gamma, draft_topk, ddtree_budget, + ddtree_chain_seed ? "true" : "false", + prompt_id, generated.size(), decode_s, prefill_s, tok_s, + accepted, proposed, accept_rate); + } + std::fflush(stdout); + return 0; +} + // ─── Main ───────────────────────────────────────────────────────── int main(int argc, char ** argv) { @@ -756,6 +1238,15 @@ int main(int argc, char ** argv) { bool draft_feature_mirror = false; bool target_split_load_draft = false; bool target_split_dflash = false; + const char * mtp_gguf_path = nullptr; + const char * mtp_prompt_path = nullptr; + const char * mtp_out_path = nullptr; + int mtp_gamma = 2; + int mtp_n_gen = 0; + int mtp_prompt_id = 0; + // MTP draft strategy: false = chain (default), true = mtp_topk. + bool mtp_use_topk = false; + int mtp_draft_topk = 4; int target_gpu = 0; int draft_gpu = 0; const char * draft_ipc_bin = nullptr; @@ -826,6 +1317,54 @@ int main(int argc, char ** argv) { target_split_dflash = true; target_split_load_draft = true; } + else if (std::strncmp(argv[i], "--draft-source=", 15) == 0) { + mtp_use_topk = (std::strcmp(argv[i] + 15, "mtp_topk") == 0); + } + else if (std::strcmp(argv[i], "--draft-source") == 0) { + if (i + 1 < argc) mtp_use_topk = (std::strcmp(argv[++i], "mtp_topk") == 0); + } + else if (std::strncmp(argv[i], "--draft-topk=", 13) == 0) { + mtp_draft_topk = std::max(1, std::atoi(argv[i] + 13)); + } + else if (std::strcmp(argv[i], "--draft-topk") == 0) { + if (i + 1 < argc) mtp_draft_topk = std::max(1, std::atoi(argv[++i])); + } + else if (std::strncmp(argv[i], "--mtp-gguf=", 11) == 0) { + mtp_gguf_path = argv[i] + 11; + } + else if (std::strcmp(argv[i], "--mtp-gguf") == 0) { + if (i + 1 < argc) mtp_gguf_path = argv[++i]; + } + else if (std::strncmp(argv[i], "--gamma=", 8) == 0) { + mtp_gamma = std::atoi(argv[i] + 8); + } + else if (std::strcmp(argv[i], "--gamma") == 0) { + if (i + 1 < argc) mtp_gamma = std::atoi(argv[++i]); + } + else if (std::strncmp(argv[i], "--prompt-bin=", 13) == 0) { + mtp_prompt_path = argv[i] + 13; + } + else if (std::strcmp(argv[i], "--prompt-bin") == 0) { + if (i + 1 < argc) mtp_prompt_path = argv[++i]; + } + else if (std::strncmp(argv[i], "--n-gen=", 8) == 0) { + mtp_n_gen = std::atoi(argv[i] + 8); + } + else if (std::strcmp(argv[i], "--n-gen") == 0) { + if (i + 1 < argc) mtp_n_gen = std::atoi(argv[++i]); + } + else if (std::strncmp(argv[i], "--out=", 6) == 0) { + mtp_out_path = argv[i] + 6; + } + else if (std::strcmp(argv[i], "--out") == 0) { + if (i + 1 < argc) mtp_out_path = argv[++i]; + } + else if (std::strncmp(argv[i], "--prompt-id=", 12) == 0) { + mtp_prompt_id = std::atoi(argv[i] + 12); + } + else if (std::strcmp(argv[i], "--prompt-id") == 0) { + if (i + 1 < argc) mtp_prompt_id = std::atoi(argv[++i]); + } else if (std::strncmp(argv[i], "--target-gpu=", 13) == 0) { target_gpu = std::max(0, std::atoi(argv[i] + 13)); } @@ -939,7 +1478,14 @@ int main(int argc, char ** argv) { g_kq_stride_pad = 256; } - if (!is_laguna && !daemon_mode && !test_window_mode && (!prompt_path || !out_path)) { + if (mtp_gguf_path) { + if (mtp_prompt_path) prompt_path = mtp_prompt_path; + if (mtp_n_gen > 0) n_gen = mtp_n_gen; + if (mtp_out_path) out_path = mtp_out_path; + } + + if (!is_laguna && !daemon_mode && !test_window_mode && !mtp_gguf_path && + (!prompt_path || !out_path)) { std::fprintf(stderr, "Missing positional arguments for non-daemon mode.\n"); return 2; } @@ -1067,6 +1613,58 @@ int main(int argc, char ** argv) { target_gpu, draft_gpu, cuda_device_count); return 2; } + if (mtp_gguf_path) { + if (target_gpus.size() > 1) { + std::fprintf(stderr, "qwen35-mtp does not support --target-gpus\n"); + return 2; + } + const int max_ctx_eff = g_max_ctx_override > 0 ? g_max_ctx_override : 4096; + // ---- MTP daemon path: load once, serve requests via daemon protocol ---- + if (daemon_mode) { + std::fprintf(stderr, + "[test_dflash] arch=qwen35+mtp daemon -> dispatching to run_qwen35_daemon " + "(mtp=%s gamma=%d max_ctx=%d stream_fd=%d)\n", + mtp_gguf_path, mtp_gamma, max_ctx_eff, stream_fd); + dflash::common::Qwen35DaemonArgs qargs; + qargs.target_path = target_path; + qargs.draft_path = nullptr; // MTP mode: no DFlash draft + qargs.device.gpu = target_gpu; + qargs.device.max_ctx = max_ctx_eff; + qargs.draft_gpu = target_gpu; + qargs.stream_fd = stream_fd; + qargs.chunk = 512; + qargs.fa_window = g_fa_window; + qargs.kq_stride_pad = g_kq_stride_pad; + qargs.draft_swa_window = 0; + qargs.draft_ctx_max = 0; + qargs.fast_rollback = false; + qargs.seq_verify = false; + qargs.ddtree_mode = ddtree_budget > 0 && ddtree_mode; + qargs.ddtree_budget = ddtree_budget; + qargs.ddtree_temp = ddtree_temp; + qargs.ddtree_chain_seed = ddtree_chain_seed; + qargs.use_feature_mirror = false; + qargs.mtp_source = dflash::common::MtpSource::ExternalDrafter; + qargs.mtp_gguf_path = mtp_gguf_path; + qargs.mtp_gamma = mtp_gamma; + qargs.mtp_use_topk = mtp_use_topk; + qargs.mtp_draft_topk = mtp_draft_topk; + return dflash::common::run_qwen35_daemon(qargs); + } + // ---- MTP file-mode harness (bench / one-shot) ---- + std::fprintf(stderr, + "[test_dflash] qwen35-mtp bench target=%s mtp=%s gamma=%d max_ctx=%d\n", + target_path, mtp_gguf_path, mtp_gamma, max_ctx_eff); + return run_qwen35_mtp_harness(target_path, mtp_gguf_path, + prompt_path, n_gen, out_path, + mtp_gamma, mtp_prompt_id, + target_gpu, max_ctx_eff, + mtp_use_topk, + mtp_draft_topk, + ddtree_budget, + ddtree_chain_seed, + ddtree_temp); + } if (target_gpus.size() > 1) { if (test_window_mode || profile_scaling) { std::fprintf(stderr, "--target-gpus path does not support test-window/profile-scaling modes\n");