feat(dflash): native Qwen3.6 MTP (NextN) runtime + contract test #153

Open
javierpazo wants to merge 1 commit into Luce-Org:main from javierpazo:xabicasa/dflash-mtp-integrated

Conversation

@javierpazo (Contributor) commented May 11, 2026

Summary

Adds the runtime side of native multi-token prediction so dflash can load Qwen3.6-MTP GGUFs (am17an-style, llama.cpp PR #22673 tensor convention) and run a target-trunk + NextN-block forward in the same ggml_cgraph. Branch rebased on top of fresh main (after #129, #132, #145, #148).

This PR is the runtime contract only. It does not yet ship the speculative decode loop that turns MTP into an actual decode speedup — that needs the target_verify graph-bucket cache (Phase 4 of dflash/docs/HANDOFF_MTP_ACCELERATION_2026-05-11.md, "Waves A-D"), which is where ~81% of decode wall time currently lives. The linear integrated decode CLI lands in #154; the bucket-cached fast path is the work after that.

What lands

Library

  • src/f16_convert.cu — tiny bf16 → f32 / f16 → f32 widen kernels. Used by the MTP token-embedding widen and reusable from the existing rollback path.
  • src/internal.h:
    • new types: TargetNextN, TargetMtpLayer, TargetMtpCache, QwenMtpGraphInputs, QwenMtpGraphOutputs
    • QwenGraphInputs::expose_pre_norm_hidden, QwenGraphOutputs::pre_norm_hidden for the trunk → MTP tensor handoff
    • TargetWeights::mtp_layers, nextn_predict_layers, gguf_block_count, tok_embd_gpu
    • No fields removed from the existing trunk API: TargetLoadPlan, load_target_gguf_partial, create_target_cache_partial, kv_k_rotated, and migrate_prefill_cache are all preserved, as upstream #148 (P2P opt-in, host-staged GPU copy, sharded target daemon) expects.
  • src/qwen35_target_graph.cpp:
    • create_target_mtp_cache / free_target_mtp_cache / reset_target_mtp_cache
    • build_qwen35_mtp_graph (RMSNorm e || RMSNorm h → eh_proj → full-attn block → SwiGLU FFN → shared head, with fallback to trunk out_norm / output when nextn.shared_head_* are absent — matches am17an pack layout)
    • expose_pre_norm_hidden wired into build_qwen35_graph (marks the final pre-norm hidden as a graph output, no CPU roundtrip)
  • src/gguf_target_loader.cpp:
    • reads qwen35.nextn_predict_layers, splits block_count into trunk + MTP tail
    • loads blk.<i>.nextn.eh_proj.weight, enorm, hnorm, optional shared_head_head / shared_head_norm / embed_tokens
    • uploads token_embd.weight to the GPU for MTP checkpoints (MTP chains proposals device-side); env var DFLASH27B_UPLOAD_TOK_EMBD overrides

Tests

  • test/test_mtp_graph_contract.cpp — synthetic-tensor contract test. No GPU model needed; ~49 nodes, runs in milliseconds. Suitable for CI.
  • test/smoke_mtp_graph.cpp — loads a real MTP GGUF, builds the NextN graph for one token, asserts the output is finite.
  • test/smoke_target_mtp_handoff.cpp — loads a real MTP GGUF and runs target + MTP in the same ggml_cgraph, proving the pre_norm_hidden handoff stays on-device.
  • test/smoke_mtp_integrated_decode.cpp — minimal greedy decode loop: target greedy + MTP greedy, accept/correct counters, tok/s summary. Functional baseline for the speculative loop.

CMake

  • f16_convert.cu added to the dflash27b library sources.
  • Four new test targets registered (test_mtp_graph_contract + three smokes), linked against dflash27b + ggml + ggml-cuda + CUDA::cudart.

Validation (RTX 6000 Ada sm_89, Qwen3.6-27B-MTP Q4_K_M, FA_WINDOW=0)

| Binary | Result |
| --- | --- |
| test_mtp_graph_contract | PASS — 49 graph nodes, shapes correct (logits [32,1], hidden [8,1]) |
| smoke_mtp_graph | PASS — logits 0 NaN / 0 Inf, range [-24.28, 14.43] |
| smoke_target_mtp_handoff | PASS — 3061 nodes, both target + MTP heads clean (no NaN/Inf) |
| smoke_mtp_integrated_decode (8 tok) | PASS — 50% greedy acceptance, 23.6 tok/s |

Loader log: target loaded: 866 tensors on GPU 15.35 GiB, tok_embd 682 MiB GPU+CPU (q4_K), trunk_layers=64 nextn=1.

GGUF compatibility

Tensor naming follows the convention introduced by llama.cpp PR #22673. Confirmed compatible with am17an/Qwen3.6-27B-MTP-GGUF and the havenoammo / froggeric Unsloth UD repacks.

Expected tail-block tensors:

blk.<n_trunk>.nextn.eh_proj.weight              [2 * hidden, hidden]
blk.<n_trunk>.nextn.embed_tokens.weight          [hidden, vocab]    (optional)
blk.<n_trunk>.nextn.enorm.weight                 [hidden]
blk.<n_trunk>.nextn.hnorm.weight                 [hidden]
blk.<n_trunk>.nextn.shared_head_head.weight      [hidden, vocab]    (optional)
blk.<n_trunk>.nextn.shared_head_norm.weight      [hidden]           (optional)

When the optional shared-head tensors are absent the runtime falls back to the trunk's output_norm / output (matches the am17an pack).

Honest scope on speedup

This PR does not claim any speedup over DFlash-classic. Early measurements compared "same MTP GGUF, MTP off" vs "same MTP GGUF, MTP chain-2" — that is not a fair baseline because the MTP-off path on an MTP GGUF still inherits the on-GPU token_embd upload and the trunk-minus-NextN layout, neither of which DFlash-classic uses. A fair comparison against DFlash-classic + PFlash on a plain Qwen3.6-27B Q4_K_M (which is what the repo ships today) needs the target_verify graph-bucket cache work described as "Waves A-D" in dflash/docs/HANDOFF_MTP_ACCELERATION_2026-05-11.md: graph_compute is ~81% of decode wall time, and without those buckets MTP's extra forward never amortizes.

So this PR ships the runtime contract and the smoke tests that pin it. The linear decode CLI is in #154. The publishable speedup PR comes after the bucket cache lands.

MoE limitation

build_qwen35_mtp_graph currently implements the dense-FFN path only. The 35B-A3B MTP GGUFs need the MoE TargetLayer fields and the routed FFN path that howard0su is upstreaming in #120 "Qwen3.5 MoE support". A MoE-aware MTP graph is a small dispatch on top of this PR once #120 lands. Until then, loading a MoE-MTP GGUF and invoking the MTP graph returns a clear error rather than wrong output.

Follow-ups

  1. Linear MTP integrated decode CLI — #154 (stacked on this PR).
  2. target_verify graph-bucket cache (Phase 4 / "Waves A-D" of the HANDOFF doc). This is the high-value perf PR.
  3. Daemon-mode --mtp-integrated wiring reusing prefix_cache.py.
  4. MoE MTP path after #120 (Qwen3.5 MoE support) merges.

Verification vs existing community PRs


Javier Pazó <xabicasa@gmail.com>

Adds the runtime side of native multi-token prediction so dflash can load
Qwen3.6-MTP GGUFs (am17an-style, llama.cpp PR #22673 tensor convention)
and run a target-trunk + NextN-block forward in the same ggml graph.

Library
- src/f16_convert.cu — small bf16/f16 → f32 widen kernels (used by the
  MTP token-embedding widen and shared with the rollback path).
- src/internal.h — new types: TargetNextN, TargetMtpLayer, TargetMtpCache,
  QwenMtpGraphInputs/Outputs; QwenGraphInputs gains expose_pre_norm_hidden,
  QwenGraphOutputs gains pre_norm_hidden; TargetWeights gains mtp_layers,
  nextn_predict_layers, gguf_block_count, tok_embd_gpu. No fields removed
  from the existing trunk API.
- src/qwen35_target_graph.cpp — create/free/reset_target_mtp_cache and
  build_qwen35_mtp_graph (RMSNorm e || RMSNorm h → eh_proj → full-attn
  block → SwiGLU FFN → shared head, falling back to trunk out_norm/output
  when nextn.shared_head_* are absent). Wires expose_pre_norm_hidden in
  build_qwen35_graph.
- src/gguf_target_loader.cpp — reads qwen35.nextn_predict_layers, splits
  block_count into trunk + tail, loads blk.<i>.nextn.* tensors into
  TargetWeights::mtp_layers, and uploads token_embd.weight to the GPU
  for MTP checkpoints so MTP can chain proposals device-side
  (DFLASH27B_UPLOAD_TOK_EMBD overrides).

Tests
- test/test_mtp_graph_contract.cpp — synthetic-tensor contract test that
  asserts build_qwen35_mtp_graph wires together correctly. No GPU model
  needed (~49 nodes, runs in milliseconds). Suitable for CI.
- test/smoke_mtp_graph.cpp — loads a real MTP GGUF, builds the NextN
  graph for one token, asserts the output is finite.
- test/smoke_target_mtp_handoff.cpp — loads a real MTP GGUF and runs
  target + MTP in the SAME ggml_cgraph, proving the pre_norm_hidden
  handoff lives entirely on-device.
- test/smoke_mtp_integrated_decode.cpp — minimal greedy decode loop:
  target greedy + MTP greedy, accept/correct counters, tok/s summary.
  Functional baseline for the upcoming speculative loop.

CMake
- f16_convert.cu added to the dflash27b library sources.
- Four new test targets registered (test_mtp_graph_contract + three
  smokes). Linked against dflash27b + ggml + ggml-cuda + CUDA::cudart.

Validation on RTX 6000 Ada (sm_89), Qwen3.6-27B-MTP Q4_K_M:
  test_mtp_graph_contract  → PASS (49 graph nodes, shapes correct)
  smoke_mtp_graph          → PASS (logits 0 NaN, 0 Inf, [-24.3, 14.4])
  smoke_target_mtp_handoff → PASS (3061 nodes, both heads clean)
  smoke_mtp_integrated_decode (8 tokens)
                           → PASS, 50% greedy acceptance, 23.6 tok/s

Honest scope
- MoE MTP is not supported in this PR (build_qwen35_mtp_graph fails fast
  with a clear message). The 35B-A3B MTP GGUFs need the MoE TargetLayer
  fields that howard0su is upstreaming in Luce-Org#120. A MoE-aware MTP graph
  is a one-line dispatch on top of this PR once Luce-Org#120 merges.
- This PR ships the runtime contract only. The integrated speculative
  decode loop (chain-2 / tree-fused / immediate-bonus), the daemon-side
  --mtp-integrated wiring, and the mtp_baseline_gate.py parity harness
  land in a follow-up PR. Measured locally on the same MTP GGUF with MTP
  disabled vs enabled (chain-2, n_gen=256) we see +36% tok/s today — but
  the speculative loop driving that number is not in this PR yet.

Compatible GGUFs include am17an/Qwen3.6-27B-MTP-GGUF and the havenoammo /
froggeric Unsloth UD repacks; tensor naming follows llama.cpp #22673.

@cubic-dev-ai (bot) left a comment


2 issues found across 10 files



1. dflash/src/f16_convert.cu:37 (P2): Missing zero-length guard causes an invalid zero-grid CUDA launch when `n_elems == 0`.
2. dflash/src/f16_convert.cu:47 (P2): Missing zero-length guard causes an invalid zero-grid CUDA launch when `n_elems == 0`.

Reply with feedback, questions, or to request a fix. Tag @cubic-dev-ai to re-run a review.

Comment thread on dflash/src/f16_convert.cu:47:

```cpp
                                             cudaStream_t stream) {
    const int threads = 256;
    const int blocks  = (int)((n_elems + threads - 1) / threads);
    bf16_to_f32_kernel<<<blocks, threads, 0, stream>>>(
        (const __nv_bfloat16 *)src, (float *)dst, n_elems);
}
```

P2: Missing zero-length guard causes an invalid zero-grid CUDA launch when `n_elems == 0`.

Comment thread on dflash/src/f16_convert.cu:37:

```cpp
                                            cudaStream_t stream) {
    const int threads = 256;
    const int blocks  = (int)((n_elems + threads - 1) / threads);
    f16_to_f32_kernel<<<blocks, threads, 0, stream>>>(
        (const __half *)src, (float *)dst, n_elems);
}
```

P2: Missing zero-length guard causes an invalid zero-grid CUDA launch when `n_elems == 0`.
