Skip to content

Commit 74f981e

Browse files
Brooooooklynclaude
andauthored
perf(qwen3.5): CUDA GDN chunked-ops prefill (~1.3–1.75× TTFT, all Qwen3.6 classes) (mlx-node#72)
## Summary Collapses the **O(T) per-step gated-delta (GDN) recurrence** on the CUDA (non-Metal) prefill path into **O(T/BT) chunk-serial batched matmuls** (cuBLAS / tensor cores) — a device-agnostic, pure-`MxArray` port of the in-tree Metal chunked kernel (`crates/mlx-sys/src/metal/gated_delta_chunked.metal.inc`). - **Zero build changes** (no nvcc/NVRTC) — matmuls route through cuBLAS. - **Default-on** for the CUDA ops path; `MLX_GDN_KERNEL=perstep` reverts for same-binary A/B. - **No Metal impact** — `use_kernel=true` never reaches this path; the Mac/Metal production path is byte-identical. This attacks the *"GDN per-step recurrence is the prefill floor"* bottleneck that [PR mlx-node#71](mlx-node#71 benchmark flagged as the **#1 dense lever**. ## Measured (GB10 / DGX Spark, Qwen3.6, warm, prefill TTFT vs per-step) | model | 1577-tok speedup | parity | |--------------|------------------|--------| | dense-Q4 | **1.62×** | byte-identical / late-drift | | dense-NVFP4 | 1.33× | identical | | MoE-Q4 | **1.75×** | coherent | | MoE-NVFP4 | 1.40× | coherent | Win **grows with prompt length** (dense-Q4: 1.06×@200 → ~1.58×@1577+; the chunked inverse is fixed-cost while per-step is O(T)). chunked prefill tok/s climbs to ~242 vs per-step's flat ~150. MoE wins more (GDN is a larger prefill fraction); NVFP4 less (dequant dominates its prefill). ## Numerical-stability fixes (each with a Mac regression test) 1. **Triangular inverse overflow.** `M = (I+A)⁻¹` by repeated squaring overflows f32 at `BT=64` (`N³² ≈ 4e57` before nilpotency zeroes it at `N⁶⁴`), producing garbage. Replaced with **row-iterative forward substitution** (FLA / vLLM `solve_tril`): `M[i,:] = eᵢ − A[i,:]·M` — no powers of A, stable for any `‖A‖`, serial depth independent of T. Test: `chunked_ops_stable_with_correlated_unit_norm_keys`. 2. **Gate underflow (MoE-only garbage).** `g_log = g.log()` round-trips through the exp-space gate; strong decay (which MoE has, dense doesn't) underflows `g` to 0 → `log(0) = -inf` → chunked `gcum_i − gcum_j = inf − inf = NaN`. Now compute `g_log = -exp(a_log)·softplus(a + dt_bias)` **directly in log-space** (matches the native `g_log` the fused Metal gating returns). Test: `compute_g_log_finite_under_strong_decay`. ## Validation - 6/6 `gated_delta` Rust unit tests, `cargo clippy -p mlx-core --all-targets`, `cargo fmt` — green. - Correctness validated on the DGX across **all four Qwen3.6 classes** (dense/MoE × Q4/NVFP4) — per-step vs chunked greedy A/B, coherent output everywhere (garbage only before the two fixes above). - Algorithm derivation in `docs/gdn-chunked-ops-spec.md`. ## Follow-ups (not blocking) - FLA 16-block row-iterative + block-merge inverse to cut short-prompt inverse depth 63→~16 (long-prompt asymptote is carry-bound, won't move). - Runtime non-finite guard → fall back to per-step (overflow is currently silent; the `Err` fallback doesn't catch `Inf`/`NaN`). - Possibly raise `CHUNK_THRESHOLD`→256 (the 200-tok win is only 1.06×). 🤖 Generated with [Claude Code](https://claude.com/claude-code) <!-- CURSOR_SUMMARY --> --- > [!NOTE] > **Medium Risk** > Touches core Qwen3.5 inference recurrence and output numerics on CUDA, though Metal is gated off, per-step fallback exists, and behavior is covered by parity/stability tests. > > **Overview** > Adds a **default-on CUDA prefill fast path** for Qwen gated-delta (GDN): long, unmasked sequences on the non-Metal ops branch now run **`gated_delta_chunked_ops`**, a pure `MxArray` chunk-parallel port (BT=64) that replaces the O(T) per-token loop with O(T/BT) chunk carries and batched matmuls. Metal/`use_kernel=true` routing is unchanged; decode and masked calls still use per-step ops. **`MLX_GDN_KERNEL=perstep`** (and **`ForceChunkedOps`** / `chunked_ops` aliases) support same-binary A/B. > > Two **numerical fixes** ship with the chunked path: **`compute_g_log`** computes the decay gate in log-space (avoids `log(0)` → NaN on strong MoE decay), and **`invert_i_plus_strict_lower`** builds `(I+A)⁻¹` via forward substitution instead of f32 power squaring that overflows at BT=64. Chunked ops errors fall back to per-step with a stderr warning. > > Adds **`docs/gdn-chunked-ops-spec.md`** plus unit tests for env parsing, chunked vs per-step parity across chunk boundaries, correlated-key inverse stability, and strong-decay gating. > > <sup>Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit efe9c8f. Bugbot is set up for automated code reviews on this repo. Configure [here](https://www.cursor.com/dashboard/bugbot).</sup> <!-- /CURSOR_SUMMARY --> --------- Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 87a59fc commit 74f981e

2 files changed

Lines changed: 443 additions & 1 deletion

File tree

0 commit comments

Comments
 (0)