feat(dflash): native Qwen3.6 MTP (NextN) runtime + contract test #153
Open — javierpazo wants to merge 1 commit into `Luce-Org:main` from `javierpazo:xabicasa/dflash-mtp-integrated`.
# Native MTP (NextN) — runtime status, 2026-05-11

This document describes the native multi-token prediction (MTP / NextN) runtime introduced into `dflash` in PR `feat(dflash): native Qwen3.6 MTP integrated decode`. It tracks the contract, what already works, and what remains for the next PR before MTP becomes a default-on decode mode.
## What this PR ships

- `dflash/src/f16_convert.cu` — small `f16/bf16 → f32` widen kernels used by both the rollback path and the MTP token-embedding widen.
- `dflash/src/internal.h` — new types:
  - `TargetNextN`, `TargetMtpLayer`
  - `TargetMtpCache` (KV cache for the NextN tail block only)
  - `QwenMtpGraphInputs`, `QwenMtpGraphOutputs`
  - `expose_pre_norm_hidden` on `QwenGraphInputs`
  - `pre_norm_hidden` on `QwenGraphOutputs`
  - `TargetWeights::mtp_layers`, `nextn_predict_layers`, `gguf_block_count`, `tok_embd_gpu` (no fields removed; the trunk API is preserved).
- `dflash/src/qwen35_target_graph.cpp` — four new functions:
  - `create_target_mtp_cache` / `free_target_mtp_cache` / `reset_target_mtp_cache`
  - `build_qwen35_mtp_graph` — RMSNorm(e) || RMSNorm(h) → `eh_proj` → full-attention transformer block → SwiGLU FFN → shared head. Also wires `expose_pre_norm_hidden` into `build_qwen35_graph`.
- `dflash/src/gguf_target_loader.cpp` — reads `qwen35.nextn_predict_layers`, splits the GGUF blocks into trunk + MTP tail, loads `blk.<i>.nextn.*` tensors into `TargetWeights::mtp_layers`, and uploads `token_embd.weight` to the GPU when the checkpoint carries MTP (the `DFLASH27B_UPLOAD_TOK_EMBD` env var overrides this).
- `dflash/test/test_mtp_graph_contract.cpp` — synthetic-tensor test that asserts the MTP graph wires together correctly. No GPU model needed; cheap to run in CI.
- `dflash/test/smoke_mtp_graph.cpp` — loads a real MTP GGUF, builds the NextN graph for a single token, and validates that the output is finite.
- `dflash/test/smoke_target_mtp_handoff.cpp` — loads a real MTP GGUF and proves that the trunk pre-norm hidden tensor feeds directly into the MTP block within the same `ggml_cgraph` (no CPU roundtrip required).
- `dflash/test/smoke_mtp_integrated_decode.cpp` — full integrated decode loop: target greedy + MTP greedy in one graph, with per-step accept / correct counters. This is the functional baseline that the upcoming PR's speculative loop will build on.
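The combine step that `build_qwen35_mtp_graph` wires up (RMSNorm(e) || RMSNorm(h) → `eh_proj`) can be sketched numerically. This is an illustrative CPU model only — the real path runs inside a ggml graph, and the helper names below (`rms_norm`, `nextn_combine`) are made up for the sketch:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative CPU sketch of the NextN combine step:
// concat(RMSNorm(e), RMSNorm(h)) -> eh_proj -> one hidden-sized vector.
// eps and all names are hypothetical; the real computation is a ggml graph.
static std::vector<float> rms_norm(const std::vector<float>& x,
                                   const std::vector<float>& w,
                                   float eps = 1e-6f) {
    float ss = 0.0f;
    for (float v : x) ss += v * v;
    const float scale = 1.0f / std::sqrt(ss / x.size() + eps);
    std::vector<float> out(x.size());
    for (size_t i = 0; i < x.size(); ++i) out[i] = x[i] * scale * w[i];
    return out;
}

// eh_proj modeled as a row-major [hidden, 2*hidden] weight applied to the
// concatenation of the normalized token embedding e and the trunk pre-norm
// hidden h, matching the eh_proj.weight [2 * hidden, hidden] shape above.
static std::vector<float> nextn_combine(const std::vector<float>& e,
                                        const std::vector<float>& h,
                                        const std::vector<float>& enorm_w,
                                        const std::vector<float>& hnorm_w,
                                        const std::vector<float>& eh_proj_w) {
    const size_t hidden = e.size();
    std::vector<float> en = rms_norm(e, enorm_w);
    std::vector<float> hn = rms_norm(h, hnorm_w);
    std::vector<float> cat(2 * hidden);
    for (size_t i = 0; i < hidden; ++i) {
        cat[i] = en[i];
        cat[hidden + i] = hn[i];
    }
    std::vector<float> out(hidden, 0.0f);
    for (size_t r = 0; r < hidden; ++r)
        for (size_t c = 0; c < 2 * hidden; ++c)
            out[r] += eh_proj_w[r * 2 * hidden + c] * cat[c];
    return out;
}
```

The output then feeds the tail transformer block exactly like a trunk hidden state would.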
## GGUF compatibility

The loader follows the tensor naming convention introduced by llama.cpp's [MTP PR #22673](https://github.com/ggml-org/llama.cpp/pull/22673). It is compatible with the reference Qwen3.6-MTP GGUFs published on the Hub:

- `am17an/Qwen3.6-27B-MTP-GGUF`
- `am17an/Qwen3.6-35BA3B-MTP-GGUF` (MoE — see "MoE limitation" below)
- `havenoammo/Qwen3.6-27B-MTP-UD-GGUF`
- `havenoammo/Qwen3.6-35B-A3B-MTP-GGUF`
- `froggeric/Qwen3.6-27B-MTP-GGUF`

The expected tail-block tensor names are:

```text
blk.<n_trunk>.nextn.eh_proj.weight            [2 * hidden, hidden]
blk.<n_trunk>.nextn.embed_tokens.weight       [hidden, vocab]   (optional)
blk.<n_trunk>.nextn.enorm.weight              [hidden]
blk.<n_trunk>.nextn.hnorm.weight              [hidden]
blk.<n_trunk>.nextn.shared_head_head.weight   [hidden, vocab]   (optional)
blk.<n_trunk>.nextn.shared_head_norm.weight   [hidden]          (optional)
```
When the optional shared-head tensors are absent, the runtime falls back to the trunk's `output_norm` / `output` (lm_head), matching how am17an's GGUFs are typically packed.
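The fallback can be sketched as a per-tensor selection. The helper below is purely illustrative (hypothetical names; the real loader populates `TargetWeights` fields directly):

```cpp
#include <cassert>

// Illustrative sketch of the shared-head fallback described above: if the
// optional blk.<n>.nextn.shared_head_* tensors were not found in the GGUF,
// the MTP graph reuses the trunk's output_norm / output tensors instead.
// Struct and function names are hypothetical.
struct HeadTensors {
    const void* norm_w;  // RMSNorm weight applied before the lm_head
    const void* head_w;  // lm_head projection weight
};

static HeadTensors pick_mtp_head(const void* shared_head_norm,   // may be null
                                 const void* shared_head_head,   // may be null
                                 const void* trunk_output_norm,
                                 const void* trunk_output) {
    HeadTensors h;
    // Fall back per tensor: a checkpoint may carry the norm but not the head.
    h.norm_w = shared_head_norm ? shared_head_norm : trunk_output_norm;
    h.head_w = shared_head_head ? shared_head_head : trunk_output;
    return h;
}
```

Falling back per tensor (rather than all-or-nothing) keeps partially packed checkpoints loadable.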
## MoE limitation

`build_qwen35_mtp_graph` currently implements the dense-FFN path only. The 35B-A3B MTP GGUFs require the MoE `TargetLayer` fields and the routed FFN path that howard0su is upstreaming in [PR #120 "Qwen3.5 MoE support"](https://github.com/Luce-Org/lucebox-hub/pull/120). A MoE-aware `build_qwen35_mtp_graph` is a one-line dispatch on top of this PR once #120 lands. Until then, loading a MoE-MTP GGUF and invoking the MTP graph returns a clear error rather than producing wrong output.
## Why MTP is opt-in, not default-on

Measured today against `DFlash + PFlash` on the same MTP GGUF with MTP disabled, on a single RTX 6000 Ada (sm_89), Qwen3.6-27B Q4_K_M target, `q4_0/q4_0` KV, FA_WINDOW=0, DDTree budget=16, draft feature mirror on:

| n_gen | Same GGUF, MTP off (tok/s) | Same GGUF, MTP chain-2 (tok/s) | Δ |
|---:|---:|---:|---:|
| 64 | 57.58 | 54.72 | **−5.0%** |
| 128 | 67.58 | 64.23 | **−5.0%** |
| 256 | 60.40 | 82.18 | **+36.1%** |

What changes between 64 and 256 tokens is that DDTree rounds drop from roughly 60 → 38 and average tokens committed per draft step rise from 4.27 → 6.74, so the extra MTP forward starts paying for itself.
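As a quick consistency check on those figures: rounds times average tokens committed per step should roughly reproduce the n_gen = 256 token count for both configurations. Illustrative arithmetic only (the helper name is made up):

```cpp
#include <cassert>
#include <cmath>

// Cross-check of the n_gen = 256 numbers quoted above: DDTree rounds times
// average tokens committed per draft step should roughly equal the number of
// tokens generated, for both the MTP-off and the MTP chain-2 run.
static double tokens_generated(double rounds, double tokens_per_step) {
    return rounds * tokens_per_step;
}
```

Both 60 × 4.27 and 38 × 6.74 land within a token or two of 256, which is why fewer, larger draft steps amortize the extra MTP forward.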
This is real but workload-dependent acceleration, not a universal default. The next PR adds the speculative loop that turns this into a default-on mode for long generations; today's PR ships only the runtime contract and the tests that pin it.

## Known follow-ups (next PR)

1. Speculative decode loop wiring (`run_mtp_integrated_prompt`, target-batched verify, fast rollback) inside `test_dflash`.
2. Daemon-side `--mtp-integrated` CLI + metrics surface (`[mtp-daemon]` line, `last_mtp` aggregated in `prefix_cache.py`).
3. `mtp_baseline_gate.py` published as a reusable parity-gate harness.
4. CPU hidden-readback elimination — the current functional smoke still round-trips token ids through the CPU between MTP steps. Removing that is the highest-value perf fix and is queued behind CUDA-graph capture.
5. MoE MTP path after PR #120 merges.
`dflash/src/f16_convert.cu` (new file):
```cuda
// Tiny half-precision → f32 conversion kernels used by the DDtree rollback
// path and the drafter's target_feat widen. We store some tensors
// (ssm_intermediate, target_feat) at 16-bit to halve their memory footprint,
// and widen on read into f32 consumers.
//
// Exposes plain C entry points so test_dflash.cpp can call them without
// pulling in a CUDA compile unit of its own.

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

static __global__ void f16_to_f32_kernel(const __half * __restrict__ src,
                                         float * __restrict__ dst,
                                         size_t n_elems) {
    const size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n_elems) {
        dst[i] = __half2float(src[i]);
    }
}

static __global__ void bf16_to_f32_kernel(const __nv_bfloat16 * __restrict__ src,
                                          float * __restrict__ dst,
                                          size_t n_elems) {
    const size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n_elems) {
        dst[i] = __bfloat162float(src[i]);
    }
}

extern "C" void dflash27b_launch_f16_to_f32(const void * src,
                                            void * dst,
                                            size_t n_elems,
                                            cudaStream_t stream) {
    const int threads = 256;
    const int blocks = (int)((n_elems + threads - 1) / threads);
    f16_to_f32_kernel<<<blocks, threads, 0, stream>>>(
        (const __half *)src, (float *)dst, n_elems);
}

extern "C" void dflash27b_launch_bf16_to_f32(const void * src,
                                             void * dst,
                                             size_t n_elems,
                                             cudaStream_t stream) {
    const int threads = 256;
    const int blocks = (int)((n_elems + threads - 1) / threads);
    bf16_to_f32_kernel<<<blocks, threads, 0, stream>>>(
        (const __nv_bfloat16 *)src, (float *)dst, n_elems);
}
```
> **Contributor review comment (P2):** Missing zero-length guard causes an invalid zero-grid CUDA launch when `n_elems == 0`.
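The flagged issue is visible without a GPU: the grid-size computation in the launchers yields `blocks == 0` when `n_elems == 0`, and a `<<<0, threads>>>` launch is invalid. A minimal host-side sketch of one way to guard it (illustrative helper, not code from the PR):

```cpp
#include <cassert>
#include <cstddef>

// Illustrative guard for the launchers above: compute the grid size, but
// report whether a kernel launch should happen at all. With n_elems == 0
// the unguarded computation produces a zero-sized grid, which CUDA rejects.
static bool should_launch(size_t n_elems, int threads, int* blocks_out) {
    if (n_elems == 0) {
        *blocks_out = 0;
        return false;  // caller skips the kernel launch entirely
    }
    *blocks_out = (int)((n_elems + threads - 1) / threads);
    return true;
}
```

A launcher would then early-return when `should_launch` is false instead of issuing the launch.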