1111
1212#pragma once
1313
14+ #include " ddtree.h"
15+
1416#include < cstdint>
1517#include < vector>
1618
19+ struct ggml_backend ;
20+ typedef struct ggml_backend * ggml_backend_t ;
21+ struct ggml_tensor ;
22+
1723namespace dflash27b {
1824
1925struct DFlashTarget {
2026 virtual ~DFlashTarget () = default ;
2127
28+ // Return the ggml backend used by this target's graph compute. Default
29+ // returns nullptr; callers (e.g. Qwen3.6 MTP) that want to build CUDA
30+ // cgraphs against the same backend should check this and fall back if
31+ // it's null.
32+ virtual ggml_backend_t backend () const { return nullptr ; }
33+
34+ // Optional: return the LM-head weight tensor on the target's backend
35+ // (shape [n_embd, n_vocab], used by ggml_mul_mat). When non-null, the
36+ // Qwen3.6 MTP step graph fuses `mul_mat(W, x_normed) -> argmax` into
37+ // its own cgraph, skipping a hidden -> host -> separate-cgraph round
38+ // trip per step. Default returns nullptr so existing targets (CPU
39+ // stubs) keep the project_hidden_to_* fallback path.
40+ virtual ggml_tensor * lm_head_weight () const { return nullptr ; }
41+
42+ // Optional: causal attention window the target's full-attn blocks use
43+ // (kv_len - fa_window). The MTP head uses the same window so it sees
44+ // the same active context. 0 means full causal context.
45+ virtual int fa_window () const { return 0 ; }
46+
2247 // ── Target forward ──────────────────────────────────────────────
2348
2449 // Run a batch of tokens through the target model. Returns the argmax
@@ -33,6 +58,26 @@ struct DFlashTarget {
3358 int & last_tok,
3459 std::vector<int32_t > * all_argmax = nullptr ) = 0;
3560
61+ // Tree-structured verify: run a flat DFS-ordered DDTree through the
62+ // target with an ancestor-only attention mask. `flat_tokens[0]` is the
63+ // tree root (= last accepted token), `flat_tokens[1..N-1]` are the
64+ // DFS-ordered tree nodes (mirroring DDTree::token_ids). `tree.n_nodes`
65+ // must equal flat_tokens.size() - 1. On success, `out_argmax` is the
66+ // target's argmax at each of the N tree positions (size == N) and
67+ // (if non-null) `out_logits` is the raw logits laid out as
68+ // [N × vocab] floats. Returns false by default so existing targets
69+ // that haven't wired tree-verify can be detected by callers; concrete
70+ // targets override to plug in build_target_step_tree + the tree mask.
71+ virtual bool verify_tree (const std::vector<int32_t > & flat_tokens,
72+ const DDTree & tree,
73+ int base_pos,
74+ std::vector<int32_t > & out_argmax,
75+ std::vector<float > * out_logits = nullptr ) {
76+ (void )flat_tokens; (void )tree; (void )base_pos;
77+ (void )out_argmax; (void )out_logits;
78+ return false ;
79+ }
80+
3681 // ── KV state management ─────────────────────────────────────────
3782
3883 // Snapshot KV cache state before speculative verify, so it can be
@@ -42,6 +87,32 @@ struct DFlashTarget {
4287 // Restore KV cache to the last snapshot (undo speculative forward).
4388 virtual bool restore_kv () = 0;
4489
90+ // Rollback DeltaNet SSM/conv + full-attn KV to the accepted-path tail of
91+ // the most recent verify_tree() call. accepted_dfs[0] must be 0 (root).
92+ // Returns false if unsupported; callers must treat false as fatal in
93+ // multi-iteration tree-spec loops (poisoned KV/SSM otherwise).
94+ virtual bool restore_kv_at_dfs (const std::vector<int > & accepted_dfs) {
95+ (void )accepted_dfs;
96+ return false ;
97+ }
98+
99+ // Roll back DeltaNet SSM/conv + full-attn KV to slot `accept_n` of the
100+ // most recent verify_batch chain. Requires chain capture enabled.
101+ // Postcondition: cache cur_pos = base_pos + accept_n + 1.
102+ // Returns false if unsupported; chain runner falls back to snapshot+recommit.
103+ virtual bool restore_kv_at_chain (int accept_n) {
104+ (void )accept_n;
105+ return false ;
106+ }
107+
108+ // Enable per-position DeltaNet intermediate capture in verify_batch.
109+ // Off by default; unsafe when n_tokens > max_verify_tokens.
110+ virtual void enable_chain_capture (bool /* on*/ ) {}
111+
112+ // Record linear-chain topology before verify_batch so restore_kv_at_chain()
113+ // can locate the rollback slot. Must be called before each capturable iter.
114+ virtual void capture_topology_for_chain (int /* n_tokens*/ , int /* base_pos*/ ) {}
115+
45116 // ── Token utilities ─────────────────────────────────────────────
46117
47118 // Check if a token is end-of-sequence for this model.
@@ -61,6 +132,21 @@ struct DFlashTarget {
61132 int n_tokens,
62133 std::vector<int32_t > & tokens_out) = 0;
63134
135+ // Optional: project draft hidden states through the target's lm_head
136+ // and return the full raw logits (n_tokens * vocab floats) on host.
137+ // Used by MTP drafters that need a top-K surface for DDTree (the
138+ // argmax path above hides the distribution). Default returns false so
139+ // existing targets compile unchanged; concrete targets that wire it
140+ // up resize `logits_out` to n_tokens * vocab and return true. The
141+ // `out_vocab` param reports the vocab dim back to the caller.
142+ virtual bool project_hidden_to_logits (const float * /* hidden*/ ,
143+ int /* n_tokens*/ ,
144+ std::vector<float > & /* logits_out*/ ,
145+ int & out_vocab) {
146+ out_vocab = 0 ;
147+ return false ;
148+ }
149+
64150 // ── Configuration for draft model ───────────────────────────────
65151
66152 // Target's hidden dimension (draft model must match).
@@ -72,6 +158,56 @@ struct DFlashTarget {
72158 // Which target layers to capture intermediate activations from.
73159 // The draft model's fc layer expects exactly this many feature slices.
74160 virtual const std::vector<int > & capture_layer_ids () const = 0;
161+
162+ // Return the backbone's final post-norm hidden state for the last committed
163+ // token (hidden_size() floats, F32). Populated by verify_batch.
164+ // Returns nullptr if not yet available (e.g. before first verify_batch).
165+ // Default implementation returns nullptr; Qwen35DFlashTarget overrides it.
166+ virtual const float * last_hidden () const { return nullptr ; }
167+
168+ // Return the full post-norm hidden sequence from the MOST RECENT
169+ // verify_batch call: n_tokens * hidden_size() floats, F32, laid out as
170+ // [token_0_hidden, token_1_hidden, ..., token_{n_tokens-1}_hidden].
171+ // *out_n_tokens is set to the number of tokens captured (matches the
172+ // n_tokens passed to verify_batch). Default returns nullptr.
173+ virtual const float * last_hidden_seq (int * out_n_tokens) const {
174+ if (out_n_tokens) *out_n_tokens = 0 ;
175+ return nullptr ;
176+ }
177+
178+ // Return the post-norm hidden at an ABSOLUTE sequence position, if that
179+ // position is covered by the most recent verify_batch's hidden capture.
180+ // The Qwen3.6 MTP head needs h_{base_pos-1} for its input pair at each
181+ // chain step, which equals last_hidden() only on the first chain step
182+ // (right after prefill); subsequent steps need a hidden from earlier in
183+ // the most recent verify_batch chunk. Returns nullptr if out of range.
184+ virtual const float * hidden_at_pos (int abs_pos) const {
185+ (void )abs_pos;
186+ return nullptr ;
187+ }
188+
189+ // Pre-final-output-norm variant of hidden_at_pos. Mirrors llama.cpp
190+ // PR #22673's `t_h_pre_norm`. The Qwen3.6 MTP head's hnorm normalises
191+ // h_prev internally; feeding it the post-output-norm tensor double-
192+ // normalises and compounds per-depth rejection on D>=2 chains. Spec-
193+ // chain callers must prefer this accessor for the outer h_prev_0 seed
194+ // and fall back to hidden_at_pos() only if it returns nullptr (e.g.
195+ // adapters that do not yet capture the pre-norm sequence). Default
196+ // returns nullptr; Qwen35DFlashTarget overrides it when hidden-seq
197+ // capture is enabled.
198+ virtual const float * hidden_at_pos_pre_norm (int abs_pos) const {
199+ (void )abs_pos;
200+ return nullptr ;
201+ }
202+
203+ // Enable per-position post-norm + pre-norm hidden capture during the
204+ // next verify_batch calls. Default no-op; Qwen35DFlashTarget overrides.
205+ virtual void enable_hidden_seq_capture (bool /* on*/ ) {}
206+
207+ // FULL_SEQ during prefill (warm_head_kv reads per-position); LAST_ROW_ONLY
208+ // during decode-side chain verifies. Default no-op.
209+ enum class VerifyCaptureScope { FULL_SEQ, LAST_ROW_ONLY };
210+ virtual void set_hidden_capture_scope (VerifyCaptureScope /* scope*/ ) {}
75211};
76212
77213} // namespace dflash27b
0 commit comments