feat(dflash): native Qwen3.6 MTP (NextN) runtime + contract test #153
Open · javierpazo wants to merge 1 commit

Conversation
Adds the runtime side of native multi-token prediction so dflash can load
Qwen3.6-MTP GGUFs (am17an-style, llama.cpp PR #22673 tensor convention)
and run a target-trunk + NextN-block forward in the same ggml graph.
Library
- src/f16_convert.cu — small bf16/f16 → f32 widen kernels (used by the
MTP token-embedding widen and shared with the rollback path).
- src/internal.h — new types: TargetNextN, TargetMtpLayer, TargetMtpCache,
QwenMtpGraphInputs/Outputs; QwenGraphInputs gains expose_pre_norm_hidden,
QwenGraphOutputs gains pre_norm_hidden; TargetWeights gains mtp_layers,
nextn_predict_layers, gguf_block_count, tok_embd_gpu. No fields removed
from the existing trunk API.
- src/qwen35_target_graph.cpp — create/free/reset_target_mtp_cache and
  build_qwen35_mtp_graph (RMSNorm e || RMSNorm h → eh_proj → full-attn
  block → SwiGLU FFN → shared head, falling back to trunk out_norm/output
  when nextn.shared_head_* are absent; the fusion step is sketched after
  this list). Wires expose_pre_norm_hidden in build_qwen35_graph.
- src/gguf_target_loader.cpp — reads qwen35.nextn_predict_layers, splits
block_count into trunk + tail, loads blk.<i>.nextn.* tensors into
TargetWeights::mtp_layers, and uploads token_embd.weight to the GPU
for MTP checkpoints so MTP can chain proposals device-side
(DFLASH27B_UPLOAD_TOK_EMBD overrides).
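
For orientation, here is a hedged sketch of the fusion step at the head of
build_qwen35_mtp_graph, written against the public ggml API. The member
names mirror the nextn.* tensors the loader reads; the struct shape, eps
value, and concat axis are assumptions, not this PR's exact code.

```cpp
#include "ggml.h"

// Sketch only: fuse the widened token embedding e with the trunk's
// pre_norm_hidden h, as described for build_qwen35_mtp_graph above.
struct TargetNextNSketch {          // stand-in for this PR's TargetNextN
    ggml_tensor * enorm;            // blk.<i>.nextn.enorm
    ggml_tensor * hnorm;            // blk.<i>.nextn.hnorm
    ggml_tensor * eh_proj;          // blk.<i>.nextn.eh_proj.weight
};

static ggml_tensor * nextn_fuse(ggml_context * ctx, const TargetNextNSketch & nn,
                                ggml_tensor * e, ggml_tensor * h) {
    const float eps = 1e-6f;  // assumed rms_norm_eps
    ggml_tensor * en = ggml_mul(ctx, ggml_rms_norm(ctx, e, eps), nn.enorm);
    ggml_tensor * hn = ggml_mul(ctx, ggml_rms_norm(ctx, h, eps), nn.hnorm);
    ggml_tensor * eh = ggml_concat(ctx, en, hn, 0);   // [2*n_embd, n_tokens]
    return ggml_mul_mat(ctx, nn.eh_proj, eh);         // back to [n_embd, n_tokens]
    // → feeds the full-attn block, SwiGLU FFN, then the shared head
}
```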
Tests
- test/test_mtp_graph_contract.cpp — synthetic-tensor contract test that
asserts build_qwen35_mtp_graph wires together correctly. No GPU model
needed (~49 nodes, runs in milliseconds). Suitable for CI.
- test/smoke_mtp_graph.cpp — loads a real MTP GGUF, builds the NextN
graph for one token, asserts the output is finite.
- test/smoke_target_mtp_handoff.cpp — loads a real MTP GGUF and runs
target + MTP in the SAME ggml_cgraph, proving the pre_norm_hidden
handoff lives entirely on-device.
- test/smoke_mtp_integrated_decode.cpp — minimal greedy decode loop:
  target greedy + MTP greedy, accept/correct counters, tok/s summary
  (counter logic sketched below). Functional baseline for the upcoming
  speculative loop.
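
The accept/correct bookkeeping in that loop reduces to comparing the MTP
head's greedy pick against the target's greedy pick at the same position.
A hedged sketch — the two helper functions are hypothetical stand-ins for
the real graph evaluations, not the test's actual API:

```cpp
#include <cstdint>
#include <cstdio>

int32_t propose_with_mtp(int32_t prev_tok);  // hypothetical: NextN head, greedy
int32_t target_greedy(int32_t prev_tok);     // hypothetical: trunk head, greedy

void greedy_mtp_loop(int32_t tok, int n_tokens) {
    int n_accept = 0, n_correct = 0;
    for (int i = 0; i < n_tokens; ++i) {
        const int32_t draft = propose_with_mtp(tok);
        const int32_t truth = target_greedy(tok);
        if (draft == truth) ++n_accept;   // MTP matched the target
        else                ++n_correct;  // target overrides the draft
        tok = truth;                      // baseline always emits the target pick
    }
    // e.g. 8 tokens with n_accept == 4 gives the 50% greedy acceptance below
    printf("acceptance: %d/%d\n", n_accept, n_tokens);
}
```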
CMake
- f16_convert.cu added to the dflash27b library sources.
- Four new test targets registered (test_mtp_graph_contract + three
smokes). Linked against dflash27b + ggml + ggml-cuda + CUDA::cudart.
Validation on RTX 6000 Ada (sm_89), Qwen3.6-27B-MTP Q4_K_M:
test_mtp_graph_contract → PASS (49 graph nodes, shapes correct)
smoke_mtp_graph → PASS (logits 0 NaN, 0 Inf, [-24.3, 14.4])
smoke_target_mtp_handoff → PASS (3061 nodes, both heads clean)
smoke_mtp_integrated_decode (8 tokens)
→ PASS, 50% greedy acceptance, 23.6 tok/s
Honest scope
- MoE MTP is not supported in this PR (build_qwen35_mtp_graph fails fast
with a clear message). The 35B-A3B MTP GGUFs need the MoE TargetLayer
fields that howard0su is upstreaming in Luce-Org#120. A MoE-aware MTP graph
is a one-line dispatch on top of this PR once Luce-Org#120 merges.
- This PR ships the runtime contract only. The integrated speculative
decode loop (chain-2 / tree-fused / immediate-bonus), the daemon-side
--mtp-integrated wiring, and the mtp_baseline_gate.py parity harness
land in a follow-up PR. Measured locally on the same MTP GGUF with MTP
disabled vs enabled (chain-2, n_gen=256) we see +36% tok/s today — but
the speculative loop driving that number is not in this PR yet.
Compatible GGUFs include am17an/Qwen3.6-27B-MTP-GGUF and the havenoammo /
froggeric Unsloth UD repacks; tensor naming follows llama.cpp #22673.
Review by @cubic-dev-ai — 2 issues found across 10 files:

- dflash/src/f16_convert.cu:37 — P2: missing zero-length guard causes an
  invalid zero-grid CUDA launch when n_elems == 0 (f16 launcher).
- dflash/src/f16_convert.cu:47 — P2: same missing zero-length guard on the
  bf16 launcher.
Inline comment on dflash/src/f16_convert.cu, line 47 (bf16 launcher):

cudaStream_t stream) {
    const int threads = 256;
    const int blocks = (int)((n_elems + threads - 1) / threads);
    bf16_to_f32_kernel<<<blocks, threads, 0, stream>>>(
        (const __nv_bfloat16 *)src, (float *)dst, n_elems);
}

P2: Missing zero-length guard causes an invalid zero-grid CUDA launch when
n_elems == 0.
Inline comment on dflash/src/f16_convert.cu, line 37 (f16 launcher):

cudaStream_t stream) {
    const int threads = 256;
    const int blocks = (int)((n_elems + threads - 1) / threads);
    f16_to_f32_kernel<<<blocks, threads, 0, stream>>>(
        (const __half *)src, (float *)dst, n_elems);
}

P2: Missing zero-length guard causes an invalid zero-grid CUDA launch when
n_elems == 0.
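
For concreteness, a minimal sketch of the guarded launcher the review is
asking for. Only the tail of the wrapper's signature appears in the diff,
so the function name, parameter names, and the kernel body here are
assumptions:

```cuda
// f16_convert.cu (sketch) — zero-length guard added per review.
// Names beyond what the diff shows are assumptions.
#include <cuda_fp16.h>
#include <cstddef>

__global__ void f16_to_f32_kernel(const __half *src, float *dst, size_t n) {
    const size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) dst[i] = __half2float(src[i]);  // widen one element
}

void f16_to_f32(const void *src, void *dst, size_t n_elems,
                cudaStream_t stream) {
    if (n_elems == 0) return;  // guard: a zero-block grid is an invalid launch
    const int threads = 256;
    const int blocks  = (int)((n_elems + threads - 1) / threads);
    f16_to_f32_kernel<<<blocks, threads, 0, stream>>>(
        (const __half *)src, (float *)dst, n_elems);
}
```

The same early return applies verbatim to the bf16_to_f32 launcher flagged
at line 47.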
Summary

Adds the runtime side of native multi-token prediction so dflash can load
Qwen3.6-MTP GGUFs (am17an-style, llama.cpp PR #22673 tensor convention) and
run a target-trunk + NextN-block forward in the same ggml_cgraph. Branch
rebased on top of fresh main (after #129, #132, #145, #148).

This PR is the runtime contract only. It does not yet ship the speculative
decode loop that turns MTP into an actual decode speedup — that needs the
target_verify graph-bucket cache (Phase 4 of
dflash/docs/HANDOFF_MTP_ACCELERATION_2026-05-11.md, "Waves A-D"), which is
where ~81% of decode wall time currently lives. The linear integrated decode
CLI lands in #154; the bucket-cached fast path is the work after that.

What lands
Library
- src/f16_convert.cu — tiny bf16 → f32 / f16 → f32 widen kernels. Used by
  the MTP token-embedding widen and reusable from the existing rollback path.
- src/internal.h — new types TargetNextN, TargetMtpLayer, TargetMtpCache,
  QwenMtpGraphInputs, QwenMtpGraphOutputs; QwenGraphInputs::expose_pre_norm_hidden
  and QwenGraphOutputs::pre_norm_hidden for the trunk → MTP tensor handoff;
  TargetWeights::mtp_layers, nextn_predict_layers, gguf_block_count,
  tok_embd_gpu. TargetLoadPlan, load_target_gguf_partial,
  create_target_cache_partial, kv_k_rotated, migrate_prefill_cache all
  preserved as upstream #148 ("feat(dflash): P2P opt-in, host-staged GPU
  copy, sharded target daemon") expects.
- src/qwen35_target_graph.cpp:
  - create_target_mtp_cache / free_target_mtp_cache / reset_target_mtp_cache
  - build_qwen35_mtp_graph (RMSNorm e || RMSNorm h → eh_proj → full-attn
    block → SwiGLU FFN → shared head, with fallback to trunk out_norm/output
    when nextn.shared_head_* are absent — matches the am17an pack layout)
  - expose_pre_norm_hidden wired into build_qwen35_graph (marks the final
    pre-norm hidden as a graph output, no CPU roundtrip)
- src/gguf_target_loader.cpp (metadata split sketched after this list):
  - reads qwen35.nextn_predict_layers, splits block_count into trunk + MTP tail
  - loads blk.<i>.nextn.eh_proj.weight, enorm, hnorm, and optional
    shared_head_head / shared_head_norm / embed_tokens
  - uploads token_embd.weight to the GPU for MTP checkpoints (MTP chains
    proposals device-side); env var DFLASH27B_UPLOAD_TOK_EMBD overrides
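
A hedged sketch of that metadata split, using the gguf API that ships with
ggml. The qwen35.block_count key name and the missing-key default are
assumptions, and error handling is elided:

```cpp
#include "gguf.h"
#include <cstdint>

// Sketch: split GGUF block_count into trunk layers + NextN (MTP) tail.
static void split_blocks(const gguf_context * g,
                         uint32_t & trunk_layers, uint32_t & nextn_layers) {
    const int64_t k_blocks = gguf_find_key(g, "qwen35.block_count");          // assumed key
    const int64_t k_nextn  = gguf_find_key(g, "qwen35.nextn_predict_layers"); // key from this PR
    const uint32_t blocks  = gguf_get_val_u32(g, k_blocks);
    nextn_layers = (k_nextn >= 0) ? gguf_get_val_u32(g, k_nextn) : 0;  // assumed default
    trunk_layers = blocks - nextn_layers;  // e.g. 65 → trunk_layers=64, nextn=1
}
```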
Tests

- test/test_mtp_graph_contract.cpp — synthetic-tensor contract test. No GPU
  model needed; ~49 nodes, runs in milliseconds. Suitable for CI.
- test/smoke_mtp_graph.cpp — loads a real MTP GGUF, builds the NextN graph
  for one token, asserts the output is finite.
- test/smoke_target_mtp_handoff.cpp — loads a real MTP GGUF and runs target
  + MTP in the same ggml_cgraph, proving the pre_norm_hidden handoff stays
  on-device.
- test/smoke_mtp_integrated_decode.cpp — minimal greedy decode loop: target
  greedy + MTP greedy, accept/correct counters, tok/s summary. Functional
  baseline for the speculative loop.

CMake
- f16_convert.cu added to the dflash27b library sources.
- Four new test targets registered (test_mtp_graph_contract + three smokes),
  linked against dflash27b + ggml + ggml-cuda + CUDA::cudart.

Validation (RTX 6000 Ada sm_89, Qwen3.6-27B-MTP Q4_K_M, FA_WINDOW=0)
test_mtp_graph_contract → PASS (49 nodes; logits [32,1], hidden [8,1])
smoke_mtp_graph → PASS (logits 0 NaN, 0 Inf, range [-24.28, 14.43])
smoke_target_mtp_handoff → PASS (3061 nodes, both heads clean)
smoke_mtp_integrated_decode (8 tok) → PASS, 50% greedy acceptance, 23.6 tok/s

Loader log:
target loaded: 866 tensors on GPU 15.35 GiB, tok_embd 682 MiB GPU+CPU
(q4_K), trunk_layers=64 nextn=1

GGUF compatibility
Tensor naming follows the convention introduced by llama.cpp PR #22673. Confirmed compatible with:
- am17an/Qwen3.6-27B-MTP-GGUF (reference)
- havenoammo/Qwen3.6-27B-MTP-UD-GGUF (Unsloth UD repack with am17an MTP grafted)
- froggeric/Qwen3.6-27B-MTP-GGUF

Expected tail-block tensors:
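(Reconstructed from the loader bullet in the Library section; the exact
suffixes on the norm and shared-head tensors are assumptions based on
llama.cpp naming.)

```text
blk.<i>.nextn.eh_proj.weight       required
blk.<i>.nextn.enorm.weight         required
blk.<i>.nextn.hnorm.weight         required
blk.<i>.nextn.shared_head_head     optional
blk.<i>.nextn.shared_head_norm     optional
blk.<i>.nextn.embed_tokens         optional
```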
When the optional shared-head tensors are absent, the runtime falls back to
the trunk's output_norm/output (matches the am17an pack).

Honest scope on speedup
This PR does not claim any speedup over DFlash-classic. Early measurements
compared "same MTP GGUF, MTP off" vs "same MTP GGUF, MTP chain-2" — that is
not a fair baseline, because the MTP-off path on an MTP GGUF still inherits
the on-GPU token_embd upload and the trunk-minus-NextN layout, neither of
which DFlash-classic uses. A fair comparison against DFlash-classic + PFlash
on a plain Qwen3.6-27B Q4_K_M (which is what the repo ships today) needs the
target_verify graph-bucket cache work described as "Waves A-D" in
dflash/docs/HANDOFF_MTP_ACCELERATION_2026-05-11.md: graph_compute is ~81% of
decode wall time, and without those buckets MTP's extra forward never
amortizes.

So this PR ships the runtime contract and the smoke tests that pin it. The
linear decode CLI is in #154. The publishable speedup PR comes after the
bucket cache lands.
MoE limitation
build_qwen35_mtp_graph currently implements the dense-FFN path only. The
35B-A3B MTP GGUFs need the MoE TargetLayer fields and the routed FFN path
that howard0su is upstreaming in #120 "Qwen3.5 MoE support". A MoE-aware MTP
graph is a small dispatch on top of this PR once #120 lands (see the sketch
below). Until then, loading a MoE-MTP GGUF and invoking the MTP graph
returns a clear error rather than wrong output.
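
A hedged sketch of where that dispatch would sit in the NextN block's FFN
step — the is_moe flag and both builder helpers are illustrative names, not
this PR's (or #120's) API:

```cpp
#include "ggml.h"
#include <cstdio>

struct MtpLayerSketch { bool is_moe; /* ... */ };   // illustrative

ggml_tensor * build_dense_swiglu_ffn(ggml_context *, const MtpLayerSketch &,
                                     ggml_tensor *); // illustrative, this PR's path

static ggml_tensor * build_nextn_ffn(ggml_context * ctx,
                                     const MtpLayerSketch & layer,
                                     ggml_tensor * cur) {
    if (layer.is_moe) {
        // today: fail fast with a clear message rather than wrong output
        fprintf(stderr, "build_qwen35_mtp_graph: MoE NextN FFN needs #120\n");
        return nullptr;
    }
    return build_dense_swiglu_ffn(ctx, layer, cur);
    // after #120 lands, the MoE branch calls the routed-FFN builder instead
}
```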
Follow-ups

- target_verify graph-bucket cache (Phase 4 / "Waves A-D" of the HANDOFF
  doc). This is the high-value perf PR.
- --mtp-integrated wiring re-using prefix_cache.py.

Verification vs existing community PRs
- Searched lucebox-hub for native MTP / NextN runtime support.

Javier Pazó — @xabicasa — xabicasa@gmail.com