Skip to content

Commit e9cd58f

Browse files
committed
mtp: native-heads MTP speculator (Qwen3.6 NextN, γ-chain)
Ports the Qwen3.6 MTP head onto the qwen35 backbone (same arch, NextN block at layer n_layer-1). Speculation runs through a new common chain runner; the existing DFlashTarget adapter handles verify/snapshot/restore. - common/mtp_interface.h: flavor-tagged IMtpModule + INativeMtp / IExternalDrafterMtp mixins. Future Gemma4 drafter plugs in via IExternalDrafterMtp without touching the chain runner. - common/mtp_chain_runner.{h,cpp}: γ-chain propose/verify/accept loop, hoisted out of the backend. Three KV-reconciliation paths (accept-all / fast rollback / recommit) share a single post-iter invariant so AR equivalence holds under recommit. - common/mtp_orchestrator.{h,cpp}: chunked prefill + warm + dispatch to chain runner. Owns only control flow; all compute lives in DFlashTarget::verify_batch and INativeMtp::step_batch graphs on the backend device. - qwen36/qwen36_mtp.{h,cpp,_graph.cpp,_loader.cpp}: GGUF tensor inventory for Qwen3.6 -MTP-GGUF, GPU warm graph, GPU step graph cached on (head_idx, fa_window, fused_lm_head, topk_k). γ is bound at attach time as the single source of truth. - qwen35: supports_mtp()/mtp() exposed through ModelBackend; generate() delegates to common::mtp::warm_and_decode when MTP is configured. Cache sized for max(γ+1, ddtree_budget+1) verify tokens. - server.py: --mtp-gguf and --mtp-gamma flags routed through; daemon command surface unchanged. Tests: 4/4 test_common_mtp_orchestrator. Full build green; harness probe 7/7 (claude_code, codex, opencode, openwebui, pi, hermes, openclaw) at --max-ctx 65536; MTP decode reports accept_rate 0.43-0.88 on short agentic prompts.
1 parent 73433ee commit e9cd58f

31 files changed

Lines changed: 6559 additions & 35 deletions

dflash/CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,11 @@ add_library(dflash27b STATIC
236236
src/qwen35/qwen35_layer_split_dflash_target.cpp
237237
src/qwen35/layer_split_daemon_loop.cpp
238238
src/qwen35/qwen35_daemon.cpp
239+
src/qwen36/qwen36_mtp.cpp
240+
src/qwen36/qwen36_mtp_graph.cpp
241+
src/qwen36/qwen36_mtp_loader.cpp
242+
src/common/mtp_chain_runner.cpp
243+
src/common/mtp_orchestrator.cpp
239244
src/common/sampler.cpp
240245
src/common/daemon_loop.cpp
241246
src/common/gguf_inspect.cpp
@@ -491,6 +496,11 @@ if(DFLASH27B_TESTS)
491496
target_include_directories(test_kv_quant PRIVATE ${DFLASH27B_SRC_INCLUDE_DIRS})
492497
target_link_libraries(test_kv_quant PRIVATE dflash27b)
493498
endif()
499+
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_common_mtp_orchestrator.cpp")
500+
add_executable(test_common_mtp_orchestrator test/test_common_mtp_orchestrator.cpp)
501+
target_include_directories(test_common_mtp_orchestrator PRIVATE ${DFLASH27B_SRC_INCLUDE_DIRS})
502+
target_link_libraries(test_common_mtp_orchestrator PRIVATE dflash27b)
503+
endif()
494504
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_draft_vs_reference.cpp")
495505
add_executable(test_draft_vs_reference test/test_draft_vs_reference.cpp)
496506
target_link_libraries(test_draft_vs_reference PRIVATE dflash27b)

dflash/scripts/server.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,11 @@ def build_app(target: Path, draft: Path | None, bin_path: Path, budget: int, max
734734
verify_mode: str = "ddtree",
735735
extra_daemon_args: list[str] | None = None,
736736
lazy_draft: bool = False,
737-
verbose_daemon: bool = False) -> FastAPI:
737+
verbose_daemon: bool = False,
738+
mtp_gguf: Path | None = None,
739+
mtp_gamma: int = 3,
740+
mtp_draft_source: str = "chain",
741+
mtp_draft_topk: int = 1) -> FastAPI:
738742
import asyncio
739743
if _extra_daemon_has_target_sharding(extra_daemon_args):
740744
if prefix_cache_slots > 0 or prefill_cache_slots > 0:
@@ -791,6 +795,19 @@ async def _openai_compat_error_handler(_request: Request, exc: OpenAICompatError
791795
cmd = [bin_abs, str(target), "--daemon",
792796
f"--max-ctx={max_ctx}",
793797
f"--stream-fd={stream_fd_val}"]
798+
elif mtp_gguf is not None:
799+
# MTP mode: no --draft (MTP head lives inside target or mtp_gguf),
800+
# no DFlash flags. Daemon dispatches to MTP code path via --mtp-gguf.
801+
cmd = [bin_abs, str(target), "--daemon",
802+
f"--max-ctx={max_ctx}",
803+
f"--stream-fd={stream_fd_val}",
804+
f"--mtp-gguf={mtp_gguf}",
805+
f"--gamma={mtp_gamma}",
806+
"--draft-source", mtp_draft_source]
807+
if mtp_draft_source == "mtp_topk":
808+
cmd.append(f"--draft-topk={mtp_draft_topk}")
809+
if extra_daemon_args:
810+
cmd.extend(extra_daemon_args)
794811
else:
795812
if draft is None:
796813
raise SystemExit("qwen35 arch requires --draft <draft.gguf|model.safetensors>")
@@ -2737,6 +2754,20 @@ def main():
27372754
help="Pass --draft-feature-mirror to test_dflash (safe cross-GPU feature path)")
27382755
ap.add_argument("--peer-access", action="store_true",
27392756
help="Pass --peer-access to test_dflash (prefer P2P memcpy when available)")
2757+
# ── MTP (Multi-Token Prediction) speculator ──────────────────────────────
2758+
# When --mtp-gguf is set, the daemon runs MTP-head speculation instead of
2759+
# DFlash+DDTree. --draft is ignored (the MTP head is in the same GGUF as
2760+
# target, or a separate fused GGUF). Prefix-cache slots are auto-disabled
2761+
# in MTP mode because RESTORE does not snapshot MTP head KV yet.
2762+
ap.add_argument("--mtp-gguf", type=Path, default=None,
2763+
help="Path to MTP-fused GGUF. When set, daemon runs MTP "
2764+
"speculation; --draft and DFlash flags are ignored.")
2765+
ap.add_argument("--mtp-gamma", type=int, default=3,
2766+
help="MTP chain depth (default 3; recommended D=3 per matrix bench)")
2767+
ap.add_argument("--mtp-draft-source", choices=["chain", "mtp_topk"], default="chain",
2768+
help="MTP draft generation strategy (default chain)")
2769+
ap.add_argument("--mtp-draft-topk", type=int, default=1,
2770+
help="Top-K for mtp_topk draft source (default 1, ignored for chain)")
27402771
add_cli_flags(ap)
27412772
args = ap.parse_args()
27422773
prefill_cfg = config_from_args(args)
@@ -2782,6 +2813,17 @@ def main():
27822813
# through the laguna daemon now, so --prefill-compression and
27832814
# --prefix-cache-slots behave the same as on the qwen35 path.
27842815
draft = None
2816+
elif args.mtp_gguf is not None:
2817+
# MTP mode: --draft is ignored; MTP head lives in the target (or in --mtp-gguf
2818+
# if separate). Force prefix/prefill cache off — RESTORE doesn't snapshot
2819+
# MTP head KV yet (planned for a follow-up PR).
2820+
if not args.mtp_gguf.is_file():
2821+
raise SystemExit(f"--mtp-gguf not found at {args.mtp_gguf}")
2822+
draft = None
2823+
if args.prefix_cache_slots > 0 or args.prefill_cache_slots > 0:
2824+
print(" [cfg] MTP mode: disabling prefix/prefill cache (MTP head KV snapshot not implemented)")
2825+
args.prefix_cache_slots = 0
2826+
args.prefill_cache_slots = 0
27852827
else:
27862828
draft = resolve_draft(args.draft) if args.draft.is_dir() else args.draft
27872829
if not draft.is_file():
@@ -2813,7 +2855,11 @@ def main():
28132855
verify_mode=args.verify_mode,
28142856
extra_daemon_args=placement.daemon_args or None,
28152857
lazy_draft=args.lazy_draft,
2816-
verbose_daemon=args.verbose_daemon)
2858+
verbose_daemon=args.verbose_daemon,
2859+
mtp_gguf=args.mtp_gguf,
2860+
mtp_gamma=args.mtp_gamma,
2861+
mtp_draft_source=args.mtp_draft_source,
2862+
mtp_draft_topk=args.mtp_draft_topk)
28172863

28182864
import uvicorn
28192865
logging.basicConfig(

dflash/src/common/attn_masks.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,24 @@ inline void build_causal_mask(std::vector<uint16_t> & out,
4646
// Build an ancestor-only attention mask for DDTree tree-structured verify.
4747
// Each query position i can attend to its ancestors in the tree (including
4848
// itself) plus all past KV positions.
49+
//
50+
// kv_pad_override: when nonzero, pin the kv (column) stride to this value
51+
// instead of the helper's natural `align_up(past_length + N - win_start,
52+
// kq_stride_pad)`. Needed when the consumer tensor was allocated with a
53+
// fixed kv extent (e.g. build_target_step_tree sizes sg.attn_mask at
54+
// align_up(cache.max_ctx + N, kq_stride_pad)) and the helper-computed
55+
// stride would not match the tensor's actual row pitch. Default 0 keeps
56+
// existing behavior.
4957
inline void build_tree_mask(const DDTree & tree, int past_length,
5058
std::vector<uint16_t> & out_mask,
5159
int kq_stride_pad,
52-
int win_start = 0) {
60+
int win_start = 0,
61+
int kv_pad_override = 0) {
5362
const int N = 1 + tree.n_nodes;
5463
const int win_len = past_length + N - win_start;
55-
const int kv_pad = align_up(win_len, kq_stride_pad);
64+
const int kv_pad = kv_pad_override > 0
65+
? align_up(kv_pad_override, kq_stride_pad)
66+
: align_up(win_len, kq_stride_pad);
5667
const int q_pad = align_up(N, KQ_MASK_PAD);
5768
out_mask.assign((size_t)kv_pad * q_pad, F16_NEG_INF);
5869
for (int q = 0; q < N; q++) {

dflash/src/common/dflash_target.h

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,39 @@
1111

1212
#pragma once
1313

14+
#include "ddtree.h"
15+
1416
#include <cstdint>
1517
#include <vector>
1618

19+
struct ggml_backend;
20+
typedef struct ggml_backend * ggml_backend_t;
21+
struct ggml_tensor;
22+
1723
namespace dflash27b {
1824

1925
struct DFlashTarget {
2026
virtual ~DFlashTarget() = default;
2127

28+
// Return the ggml backend used by this target's graph compute. Default
29+
// returns nullptr; callers (e.g. Qwen3.6 MTP) that want to build CUDA
30+
// cgraphs against the same backend should check this and fall back if
31+
// it's null.
32+
virtual ggml_backend_t backend() const { return nullptr; }
33+
34+
// Optional: return the LM-head weight tensor on the target's backend
35+
// (shape [n_embd, n_vocab], used by ggml_mul_mat). When non-null, the
36+
// Qwen3.6 MTP step graph fuses `mul_mat(W, x_normed) -> argmax` into
37+
// its own cgraph, skipping a hidden -> host -> separate-cgraph round
38+
// trip per step. Default returns nullptr so existing targets (CPU
39+
// stubs) keep the project_hidden_to_* fallback path.
40+
virtual ggml_tensor * lm_head_weight() const { return nullptr; }
41+
42+
// Optional: causal attention window the target's full-attn blocks use
43+
// (kv_len - fa_window). The MTP head uses the same window so it sees
44+
// the same active context. 0 means full causal context.
45+
virtual int fa_window() const { return 0; }
46+
2247
// ── Target forward ──────────────────────────────────────────────
2348

2449
// Run a batch of tokens through the target model. Returns the argmax
@@ -33,6 +58,26 @@ struct DFlashTarget {
3358
int & last_tok,
3459
std::vector<int32_t> * all_argmax = nullptr) = 0;
3560

61+
// Tree-structured verify: run a flat DFS-ordered DDTree through the
62+
// target with an ancestor-only attention mask. `flat_tokens[0]` is the
63+
// tree root (= last accepted token), `flat_tokens[1..N-1]` are the
64+
// DFS-ordered tree nodes (mirroring DDTree::token_ids). `tree.n_nodes`
65+
// must equal flat_tokens.size() - 1. On success, `out_argmax` is the
66+
// target's argmax at each of the N tree positions (size == N) and
67+
// (if non-null) `out_logits` is the raw logits laid out as
68+
// [N × vocab] floats. Returns false by default so existing targets
69+
// that haven't wired tree-verify can be detected by callers; concrete
70+
// targets override to plug in build_target_step_tree + the tree mask.
71+
virtual bool verify_tree(const std::vector<int32_t> & flat_tokens,
72+
const DDTree & tree,
73+
int base_pos,
74+
std::vector<int32_t> & out_argmax,
75+
std::vector<float> * out_logits = nullptr) {
76+
(void)flat_tokens; (void)tree; (void)base_pos;
77+
(void)out_argmax; (void)out_logits;
78+
return false;
79+
}
80+
3681
// ── KV state management ─────────────────────────────────────────
3782

3883
// Snapshot KV cache state before speculative verify, so it can be
@@ -42,6 +87,32 @@ struct DFlashTarget {
4287
// Restore KV cache to the last snapshot (undo speculative forward).
4388
virtual bool restore_kv() = 0;
4489

90+
// Rollback DeltaNet SSM/conv + full-attn KV to the accepted-path tail of
91+
// the most recent verify_tree() call. accepted_dfs[0] must be 0 (root).
92+
// Returns false if unsupported; callers must treat false as fatal in
93+
// multi-iteration tree-spec loops (poisoned KV/SSM otherwise).
94+
virtual bool restore_kv_at_dfs(const std::vector<int> & accepted_dfs) {
95+
(void)accepted_dfs;
96+
return false;
97+
}
98+
99+
// Roll back DeltaNet SSM/conv + full-attn KV to slot `accept_n` of the
100+
// most recent verify_batch chain. Requires chain capture enabled.
101+
// Postcondition: cache cur_pos = base_pos + accept_n + 1.
102+
// Returns false if unsupported; chain runner falls back to snapshot+recommit.
103+
virtual bool restore_kv_at_chain(int accept_n) {
104+
(void)accept_n;
105+
return false;
106+
}
107+
108+
// Enable per-position DeltaNet intermediate capture in verify_batch.
109+
// Off by default; unsafe when n_tokens > max_verify_tokens.
110+
virtual void enable_chain_capture(bool /*on*/) {}
111+
112+
// Record linear-chain topology before verify_batch so restore_kv_at_chain()
113+
// can locate the rollback slot. Must be called before each capturable iter.
114+
virtual void capture_topology_for_chain(int /*n_tokens*/, int /*base_pos*/) {}
115+
45116
// ── Token utilities ─────────────────────────────────────────────
46117

47118
// Check if a token is end-of-sequence for this model.
@@ -61,6 +132,21 @@ struct DFlashTarget {
61132
int n_tokens,
62133
std::vector<int32_t> & tokens_out) = 0;
63134

135+
// Optional: project draft hidden states through the target's lm_head
136+
// and return the full raw logits (n_tokens * vocab floats) on host.
137+
// Used by MTP drafters that need a top-K surface for DDTree (the
138+
// argmax path above hides the distribution). Default returns false so
139+
// existing targets compile unchanged; concrete targets that wire it
140+
// up resize `logits_out` to n_tokens * vocab and return true. The
141+
// `out_vocab` param reports the vocab dim back to the caller.
142+
virtual bool project_hidden_to_logits(const float * /*hidden*/,
143+
int /*n_tokens*/,
144+
std::vector<float> & /*logits_out*/,
145+
int & out_vocab) {
146+
out_vocab = 0;
147+
return false;
148+
}
149+
64150
// ── Configuration for draft model ───────────────────────────────
65151

66152
// Target's hidden dimension (draft model must match).
@@ -72,6 +158,56 @@ struct DFlashTarget {
72158
// Which target layers to capture intermediate activations from.
73159
// The draft model's fc layer expects exactly this many feature slices.
74160
virtual const std::vector<int> & capture_layer_ids() const = 0;
161+
162+
// Return the backbone's final post-norm hidden state for the last committed
163+
// token (hidden_size() floats, F32). Populated by verify_batch.
164+
// Returns nullptr if not yet available (e.g. before first verify_batch).
165+
// Default implementation returns nullptr; Qwen35DFlashTarget overrides it.
166+
virtual const float * last_hidden() const { return nullptr; }
167+
168+
// Return the full post-norm hidden sequence from the MOST RECENT
169+
// verify_batch call: n_tokens * hidden_size() floats, F32, laid out as
170+
// [token_0_hidden, token_1_hidden, ..., token_{n_tokens-1}_hidden].
171+
// *out_n_tokens is set to the number of tokens captured (matches the
172+
// n_tokens passed to verify_batch). Default returns nullptr.
173+
virtual const float * last_hidden_seq(int * out_n_tokens) const {
174+
if (out_n_tokens) *out_n_tokens = 0;
175+
return nullptr;
176+
}
177+
178+
// Return the post-norm hidden at an ABSOLUTE sequence position, if that
179+
// position is covered by the most recent verify_batch's hidden capture.
180+
// The Qwen3.6 MTP head needs h_{base_pos-1} for its input pair at each
181+
// chain step, which equals last_hidden() only on the first chain step
182+
// (right after prefill); subsequent steps need a hidden from earlier in
183+
// the most recent verify_batch chunk. Returns nullptr if out of range.
184+
virtual const float * hidden_at_pos(int abs_pos) const {
185+
(void)abs_pos;
186+
return nullptr;
187+
}
188+
189+
// Pre-final-output-norm variant of hidden_at_pos. Mirrors llama.cpp
190+
// PR #22673's `t_h_pre_norm`. The Qwen3.6 MTP head's hnorm normalises
191+
// h_prev internally; feeding it the post-output-norm tensor double-
192+
// normalises and compounds per-depth rejection on D>=2 chains. Spec-
193+
// chain callers must prefer this accessor for the outer h_prev_0 seed
194+
// and fall back to hidden_at_pos() only if it returns nullptr (e.g.
195+
// adapters that do not yet capture the pre-norm sequence). Default
196+
// returns nullptr; Qwen35DFlashTarget overrides it when hidden-seq
197+
// capture is enabled.
198+
virtual const float * hidden_at_pos_pre_norm(int abs_pos) const {
199+
(void)abs_pos;
200+
return nullptr;
201+
}
202+
203+
// Enable per-position post-norm + pre-norm hidden capture during the
204+
// next verify_batch calls. Default no-op; Qwen35DFlashTarget overrides.
205+
virtual void enable_hidden_seq_capture(bool /*on*/) {}
206+
207+
// FULL_SEQ during prefill (warm_head_kv reads per-position); LAST_ROW_ONLY
208+
// during decode-side chain verifies. Default no-op.
209+
enum class VerifyCaptureScope { FULL_SEQ, LAST_ROW_ONLY };
210+
virtual void set_hidden_capture_scope(VerifyCaptureScope /*scope*/) {}
75211
};
76212

77213
} // namespace dflash27b

0 commit comments

Comments
 (0)