Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions dflash/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,11 @@ add_library(dflash27b STATIC
src/qwen35/qwen35_layer_split_dflash_target.cpp
src/qwen35/layer_split_daemon_loop.cpp
src/qwen35/qwen35_daemon.cpp
src/qwen36/qwen36_mtp.cpp
src/qwen36/qwen36_mtp_graph.cpp
src/qwen36/qwen36_mtp_loader.cpp
src/common/mtp_chain_runner.cpp
src/common/mtp_orchestrator.cpp
src/common/sampler.cpp
src/common/daemon_loop.cpp
src/common/gguf_inspect.cpp
Expand Down Expand Up @@ -491,6 +496,11 @@ if(DFLASH27B_TESTS)
target_include_directories(test_kv_quant PRIVATE ${DFLASH27B_SRC_INCLUDE_DIRS})
target_link_libraries(test_kv_quant PRIVATE dflash27b)
endif()
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_common_mtp_orchestrator.cpp")
add_executable(test_common_mtp_orchestrator test/test_common_mtp_orchestrator.cpp)
target_include_directories(test_common_mtp_orchestrator PRIVATE ${DFLASH27B_SRC_INCLUDE_DIRS})
target_link_libraries(test_common_mtp_orchestrator PRIVATE dflash27b)
endif()
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_draft_vs_reference.cpp")
add_executable(test_draft_vs_reference test/test_draft_vs_reference.cpp)
target_link_libraries(test_draft_vs_reference PRIVATE dflash27b)
Expand Down
50 changes: 48 additions & 2 deletions dflash/scripts/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,11 @@ def build_app(target: Path, draft: Path | None, bin_path: Path, budget: int, max
verify_mode: str = "ddtree",
extra_daemon_args: list[str] | None = None,
lazy_draft: bool = False,
verbose_daemon: bool = False) -> FastAPI:
verbose_daemon: bool = False,
mtp_gguf: Path | None = None,
mtp_gamma: int = 3,
mtp_draft_source: str = "chain",
mtp_draft_topk: int = 1) -> FastAPI:
import asyncio
if _extra_daemon_has_target_sharding(extra_daemon_args):
if prefix_cache_slots > 0 or prefill_cache_slots > 0:
Expand Down Expand Up @@ -791,6 +795,19 @@ async def _openai_compat_error_handler(_request: Request, exc: OpenAICompatError
cmd = [bin_abs, str(target), "--daemon",
f"--max-ctx={max_ctx}",
f"--stream-fd={stream_fd_val}"]
elif mtp_gguf is not None:
# MTP mode: no --draft (MTP head lives inside target or mtp_gguf),
# no DFlash flags. Daemon dispatches to MTP code path via --mtp-gguf.
cmd = [bin_abs, str(target), "--daemon",
f"--max-ctx={max_ctx}",
f"--stream-fd={stream_fd_val}",
f"--mtp-gguf={mtp_gguf}",
f"--gamma={mtp_gamma}",
"--draft-source", mtp_draft_source]
if mtp_draft_source == "mtp_topk":
cmd.append(f"--draft-topk={mtp_draft_topk}")
if extra_daemon_args:
cmd.extend(extra_daemon_args)
else:
if draft is None:
raise SystemExit("qwen35 arch requires --draft <draft.gguf|model.safetensors>")
Expand Down Expand Up @@ -2737,6 +2754,20 @@ def main():
help="Pass --draft-feature-mirror to test_dflash (safe cross-GPU feature path)")
ap.add_argument("--peer-access", action="store_true",
help="Pass --peer-access to test_dflash (prefer P2P memcpy when available)")
# ── MTP (Multi-Token Prediction) speculator ──────────────────────────────
# When --mtp-gguf is set, the daemon runs MTP-head speculation instead of
# DFlash+DDTree. --draft is ignored (the MTP head is in the same GGUF as
# target, or a separate fused GGUF). Prefix-cache slots are auto-disabled
# in MTP mode because RESTORE does not snapshot MTP head KV yet.
ap.add_argument("--mtp-gguf", type=Path, default=None,
help="Path to MTP-fused GGUF. When set, daemon runs MTP "
"speculation; --draft and DFlash flags are ignored.")
ap.add_argument("--mtp-gamma", type=int, default=3,
help="MTP chain depth (default 3; recommended D=3 per matrix bench)")
ap.add_argument("--mtp-draft-source", choices=["chain", "mtp_topk"], default="chain",
help="MTP draft generation strategy (default chain)")
ap.add_argument("--mtp-draft-topk", type=int, default=1,
help="Top-K for mtp_topk draft source (default 1, ignored for chain)")
add_cli_flags(ap)
args = ap.parse_args()
prefill_cfg = config_from_args(args)
Expand Down Expand Up @@ -2782,6 +2813,17 @@ def main():
# through the laguna daemon now, so --prefill-compression and
# --prefix-cache-slots behave the same as on the qwen35 path.
draft = None
elif args.mtp_gguf is not None:
# MTP mode: --draft is ignored; MTP head lives in the target (or in --mtp-gguf
# if separate). Force prefix/prefill cache off — RESTORE doesn't snapshot
# MTP head KV yet (planned for a follow-up PR).
if not args.mtp_gguf.is_file():
raise SystemExit(f"--mtp-gguf not found at {args.mtp_gguf}")
draft = None
if args.prefix_cache_slots > 0 or args.prefill_cache_slots > 0:
print(" [cfg] MTP mode: disabling prefix/prefill cache (MTP head KV snapshot not implemented)")
args.prefix_cache_slots = 0
args.prefill_cache_slots = 0
else:
draft = resolve_draft(args.draft) if args.draft.is_dir() else args.draft
if not draft.is_file():
Expand Down Expand Up @@ -2813,7 +2855,11 @@ def main():
verify_mode=args.verify_mode,
extra_daemon_args=placement.daemon_args or None,
lazy_draft=args.lazy_draft,
verbose_daemon=args.verbose_daemon)
verbose_daemon=args.verbose_daemon,
mtp_gguf=args.mtp_gguf,
mtp_gamma=args.mtp_gamma,
mtp_draft_source=args.mtp_draft_source,
mtp_draft_topk=args.mtp_draft_topk)

import uvicorn
logging.basicConfig(
Expand Down
15 changes: 13 additions & 2 deletions dflash/src/common/attn_masks.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,24 @@ inline void build_causal_mask(std::vector<uint16_t> & out,
// Build an ancestor-only attention mask for DDTree tree-structured verify.
// Each query position i can attend to its ancestors in the tree (including
// itself) plus all past KV positions.
//
// kv_pad_override: when nonzero, pin the kv (column) stride to this value
// instead of the helper's natural `align_up(past_length + N - win_start,
// kq_stride_pad)`. Needed when the consumer tensor was allocated with a
// fixed kv extent (e.g. build_target_step_tree sizes sg.attn_mask at
// align_up(cache.max_ctx + N, kq_stride_pad)) and the helper-computed
// stride would not match the tensor's actual row pitch. Default 0 keeps
// existing behavior.
inline void build_tree_mask(const DDTree & tree, int past_length,
std::vector<uint16_t> & out_mask,
int kq_stride_pad,
int win_start = 0) {
int win_start = 0,
int kv_pad_override = 0) {
const int N = 1 + tree.n_nodes;
const int win_len = past_length + N - win_start;
const int kv_pad = align_up(win_len, kq_stride_pad);
const int kv_pad = kv_pad_override > 0
? align_up(kv_pad_override, kq_stride_pad)
: align_up(win_len, kq_stride_pad);
const int q_pad = align_up(N, KQ_MASK_PAD);
out_mask.assign((size_t)kv_pad * q_pad, F16_NEG_INF);
for (int q = 0; q < N; q++) {
Expand Down
136 changes: 136 additions & 0 deletions dflash/src/common/dflash_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,39 @@

#pragma once

#include "ddtree.h"

#include <cstdint>
#include <vector>

struct ggml_backend;
typedef struct ggml_backend * ggml_backend_t;
struct ggml_tensor;

namespace dflash27b {

struct DFlashTarget {
virtual ~DFlashTarget() = default;

// Return the ggml backend used by this target's graph compute. Default
// returns nullptr; callers (e.g. Qwen3.6 MTP) that want to build CUDA
// cgraphs against the same backend should check this and fall back if
// it's null.
virtual ggml_backend_t backend() const { return nullptr; }

// Optional: return the LM-head weight tensor on the target's backend
// (shape [n_embd, n_vocab], used by ggml_mul_mat). When non-null, the
// Qwen3.6 MTP step graph fuses `mul_mat(W, x_normed) -> argmax` into
// its own cgraph, skipping a hidden -> host -> separate-cgraph round
// trip per step. Default returns nullptr so existing targets (CPU
// stubs) keep the project_hidden_to_* fallback path.
virtual ggml_tensor * lm_head_weight() const { return nullptr; }

// Optional: causal attention window the target's full-attn blocks use
// (kv_len - fa_window). The MTP head uses the same window so it sees
// the same active context. 0 means full causal context.
virtual int fa_window() const { return 0; }

// ── Target forward ──────────────────────────────────────────────

// Run a batch of tokens through the target model. Returns the argmax
Expand All @@ -33,6 +58,26 @@ struct DFlashTarget {
int & last_tok,
std::vector<int32_t> * all_argmax = nullptr) = 0;

// Tree-structured verify: run a flat DFS-ordered DDTree through the
// target with an ancestor-only attention mask. `flat_tokens[0]` is the
// tree root (= last accepted token), `flat_tokens[1..N-1]` are the
// DFS-ordered tree nodes (mirroring DDTree::token_ids). `tree.n_nodes`
// must equal flat_tokens.size() - 1. On success, `out_argmax` is the
// target's argmax at each of the N tree positions (size == N) and
// (if non-null) `out_logits` is the raw logits laid out as
// [N × vocab] floats. Returns false by default so existing targets
// that haven't wired tree-verify can be detected by callers; concrete
// targets override to plug in build_target_step_tree + the tree mask.
virtual bool verify_tree(const std::vector<int32_t> & flat_tokens,
const DDTree & tree,
int base_pos,
std::vector<int32_t> & out_argmax,
std::vector<float> * out_logits = nullptr) {
(void)flat_tokens; (void)tree; (void)base_pos;
(void)out_argmax; (void)out_logits;
return false;
}

// ── KV state management ─────────────────────────────────────────

// Snapshot KV cache state before speculative verify, so it can be
Expand All @@ -42,6 +87,32 @@ struct DFlashTarget {
// Restore KV cache to the last snapshot (undo speculative forward).
virtual bool restore_kv() = 0;

// Rollback DeltaNet SSM/conv + full-attn KV to the accepted-path tail of
// the most recent verify_tree() call. accepted_dfs[0] must be 0 (root).
// Returns false if unsupported; callers must treat false as fatal in
// multi-iteration tree-spec loops (poisoned KV/SSM otherwise).
virtual bool restore_kv_at_dfs(const std::vector<int> & accepted_dfs) {
(void)accepted_dfs;
return false;
}

// Roll back DeltaNet SSM/conv + full-attn KV to slot `accept_n` of the
// most recent verify_batch chain. Requires chain capture enabled.
// Postcondition: cache cur_pos = base_pos + accept_n + 1.
// Returns false if unsupported; chain runner falls back to snapshot+recommit.
virtual bool restore_kv_at_chain(int accept_n) {
(void)accept_n;
return false;
}

// Enable per-position DeltaNet intermediate capture in verify_batch.
// Off by default; unsafe when n_tokens > max_verify_tokens.
virtual void enable_chain_capture(bool /*on*/) {}

// Record linear-chain topology before verify_batch so restore_kv_at_chain()
// can locate the rollback slot. Must be called before each capturable iter.
virtual void capture_topology_for_chain(int /*n_tokens*/, int /*base_pos*/) {}

// ── Token utilities ─────────────────────────────────────────────

// Check if a token is end-of-sequence for this model.
Expand All @@ -61,6 +132,21 @@ struct DFlashTarget {
int n_tokens,
std::vector<int32_t> & tokens_out) = 0;

// Optional: project draft hidden states through the target's lm_head
// and return the full raw logits (n_tokens * vocab floats) on host.
// Used by MTP drafters that need a top-K surface for DDTree (the
// argmax path above hides the distribution). Default returns false so
// existing targets compile unchanged; concrete targets that wire it
// up resize `logits_out` to n_tokens * vocab and return true. The
// `out_vocab` param reports the vocab dim back to the caller.
virtual bool project_hidden_to_logits(const float * /*hidden*/,
int /*n_tokens*/,
std::vector<float> & /*logits_out*/,
int & out_vocab) {
out_vocab = 0;
return false;
}

// ── Configuration for draft model ───────────────────────────────

// Target's hidden dimension (draft model must match).
Expand All @@ -72,6 +158,56 @@ struct DFlashTarget {
// Which target layers to capture intermediate activations from.
// The draft model's fc layer expects exactly this many feature slices.
virtual const std::vector<int> & capture_layer_ids() const = 0;

// Return the backbone's final post-norm hidden state for the last committed
// token (hidden_size() floats, F32). Populated by verify_batch.
// Returns nullptr if not yet available (e.g. before first verify_batch).
// Default implementation returns nullptr; Qwen35DFlashTarget overrides it.
virtual const float * last_hidden() const { return nullptr; }

// Return the full post-norm hidden sequence from the MOST RECENT
// verify_batch call: n_tokens * hidden_size() floats, F32, laid out as
// [token_0_hidden, token_1_hidden, ..., token_{n_tokens-1}_hidden].
// *out_n_tokens is set to the number of tokens captured (matches the
// n_tokens passed to verify_batch). Default returns nullptr.
virtual const float * last_hidden_seq(int * out_n_tokens) const {
if (out_n_tokens) *out_n_tokens = 0;
return nullptr;
}

// Return the post-norm hidden at an ABSOLUTE sequence position, if that
// position is covered by the most recent verify_batch's hidden capture.
// The Qwen3.6 MTP head needs h_{base_pos-1} for its input pair at each
// chain step, which equals last_hidden() only on the first chain step
// (right after prefill); subsequent steps need a hidden from earlier in
// the most recent verify_batch chunk. Returns nullptr if out of range.
virtual const float * hidden_at_pos(int abs_pos) const {
(void)abs_pos;
return nullptr;
}

// Pre-final-output-norm variant of hidden_at_pos. Mirrors llama.cpp
// PR #22673's `t_h_pre_norm`. The Qwen3.6 MTP head's hnorm normalises
// h_prev internally; feeding it the post-output-norm tensor double-
// normalises and compounds per-depth rejection on D>=2 chains. Spec-
// chain callers must prefer this accessor for the outer h_prev_0 seed
// and fall back to hidden_at_pos() only if it returns nullptr (e.g.
// adapters that do not yet capture the pre-norm sequence). Default
// returns nullptr; Qwen35DFlashTarget overrides it when hidden-seq
// capture is enabled.
virtual const float * hidden_at_pos_pre_norm(int abs_pos) const {
(void)abs_pos;
return nullptr;
}

// Enable per-position post-norm + pre-norm hidden capture during the
// next verify_batch calls. Default no-op; Qwen35DFlashTarget overrides.
virtual void enable_hidden_seq_capture(bool /*on*/) {}

// FULL_SEQ during prefill (warm_head_kv reads per-position); LAST_ROW_ONLY
// during decode-side chain verifies. Default no-op.
enum class VerifyCaptureScope { FULL_SEQ, LAST_ROW_ONLY };
virtual void set_hidden_capture_scope(VerifyCaptureScope /*scope*/) {}
};

} // namespace dflash27b
Loading
Loading