Commit 74f981e
perf(qwen3.5): CUDA GDN chunked-ops prefill (~1.3–1.75× TTFT, all Qwen3.6 classes) (mlx-node#72)
## Summary
Collapses the **O(T) per-step gated-delta (GDN) recurrence** on the CUDA
(non-Metal) prefill path into **O(T/BT) chunk-serial batched matmuls**
(cuBLAS / tensor cores) — a device-agnostic, pure-`MxArray` port of the
in-tree Metal chunked kernel
(`crates/mlx-sys/src/metal/gated_delta_chunked.metal.inc`).
- **Zero build changes** (no nvcc/NVRTC) — matmuls route through cuBLAS.
- **Default-on** for the CUDA ops path; `MLX_GDN_KERNEL=perstep` reverts
for same-binary A/B.
- **No Metal impact** — `use_kernel=true` never reaches this path; the
Mac/Metal production path is byte-identical.
This attacks the *"GDN per-step recurrence is the prefill floor"*
bottleneck that [PR mlx-node#71](mlx-node#71
benchmark flagged as the **#1 dense lever**.
## Measured (GB10 / DGX Spark, Qwen3.6, warm, prefill TTFT vs per-step)
| model | 1577-tok speedup | parity |
|--------------|------------------|--------|
| dense-Q4 | **1.62×** | byte-identical / late-drift |
| dense-NVFP4 | 1.33× | identical |
| MoE-Q4 | **1.75×** | coherent |
| MoE-NVFP4 | 1.40× | coherent |
Win **grows with prompt length** (dense-Q4: 1.06×@200 → ~1.58×@1577+;
the chunked inverse is fixed-cost while per-step is O(T)). chunked
prefill tok/s climbs to ~242 vs per-step's flat ~150. MoE wins more (GDN
is a larger prefill fraction); NVFP4 less (dequant dominates its
prefill).
## Numerical-stability fixes (each with a Mac regression test)
1. **Triangular inverse overflow.** `M = (I+A)⁻¹` by repeated squaring
overflows f32 at `BT=64` (`N³² ≈ 4e57` before nilpotency zeroes it at
`N⁶⁴`), producing garbage. Replaced with **row-iterative forward
substitution** (FLA / vLLM `solve_tril`): `M[i,:] = eᵢ − A[i,:]·M` — no
powers of A, stable for any `‖A‖`, serial depth independent of T. Test:
`chunked_ops_stable_with_correlated_unit_norm_keys`.
2. **Gate underflow (MoE-only garbage).** `g_log = g.log()` round-trips
through the exp-space gate; strong decay (which MoE has, dense doesn't)
underflows `g` to 0 → `log(0) = -inf` → chunked `gcum_i − gcum_j = inf −
inf = NaN`. Now compute `g_log = -exp(a_log)·softplus(a + dt_bias)`
**directly in log-space** (matches the native `g_log` the fused Metal
gating returns). Test: `compute_g_log_finite_under_strong_decay`.
## Validation
- 6/6 `gated_delta` Rust unit tests, `cargo clippy -p mlx-core
--all-targets`, `cargo fmt` — green.
- Correctness validated on the DGX across **all four Qwen3.6 classes**
(dense/MoE × Q4/NVFP4) — per-step vs chunked greedy A/B, coherent output
everywhere (garbage only before the two fixes above).
- Algorithm derivation in `docs/gdn-chunked-ops-spec.md`.
## Follow-ups (not blocking)
- FLA 16-block row-iterative + block-merge inverse to cut short-prompt
inverse depth 63→~16 (long-prompt asymptote is carry-bound, won't move).
- Runtime non-finite guard → fall back to per-step (overflow is
currently silent; the `Err` fallback doesn't catch `Inf`/`NaN`).
- Possibly raise `CHUNK_THRESHOLD`→256 (the 200-tok win is only 1.06×).
🤖 Generated with [Claude Code](https://claude.com/claude-code)
<!-- CURSOR_SUMMARY -->
---
> [!NOTE]
> **Medium Risk**
> Touches core Qwen3.5 inference recurrence and output numerics on CUDA,
though Metal is gated off, per-step fallback exists, and behavior is
covered by parity/stability tests.
>
> **Overview**
> Adds a **default-on CUDA prefill fast path** for Qwen gated-delta
(GDN): long, unmasked sequences on the non-Metal ops branch now run
**`gated_delta_chunked_ops`**, a pure `MxArray` chunk-parallel port
(BT=64) that replaces the O(T) per-token loop with O(T/BT) chunk carries
and batched matmuls. Metal/`use_kernel=true` routing is unchanged;
decode and masked calls still use per-step ops.
**`MLX_GDN_KERNEL=perstep`** (and **`ForceChunkedOps`** / `chunked_ops`
aliases) support same-binary A/B.
>
> Two **numerical fixes** ship with the chunked path:
**`compute_g_log`** computes the decay gate in log-space (avoids
`log(0)` → NaN on strong MoE decay), and
**`invert_i_plus_strict_lower`** builds `(I+A)⁻¹` via forward
substitution instead of f32 power squaring that overflows at BT=64.
Chunked ops errors fall back to per-step with a stderr warning.
>
> Adds **`docs/gdn-chunked-ops-spec.md`** plus unit tests for env
parsing, chunked vs per-step parity across chunk boundaries,
correlated-key inverse stability, and strong-decay gating.
>
> <sup>Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit
efe9c8f. Bugbot is set up for automated
code reviews on this repo. Configure
[here](https://www.cursor.com/dashboard/bugbot).</sup>
<!-- /CURSOR_SUMMARY -->
---------
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>1 parent 87a59fc commit 74f981e
2 files changed
Lines changed: 443 additions & 1 deletion
0 commit comments