Add gated_delta_update VJP — trainable GatedDeltaNet on Apple Silicon (fixes #482)#1168
Closed
SudarkinV wants to merge 6 commits into ml-explore:main from
Conversation
added 4 commits
April 19, 2026 19:36
- gated_delta_vjp.py: pure-Python reference VJP with mx.checkpoint.
- gated_delta_vjp_metal.py: Metal kernel backward, 8-11x faster.
- gated_delta_chunk_parallel.py: rank-C chunk-parallel variant.
- qwen3_5.py / qwen3_next.py / kimi_linear.py: pick VJP in training, fall back to existing forward elsewhere.
- tests/test_gated_delta_vjp.py: 12 pytest cases (FD, Metal, mask).

Power-iteration rank truncation at chunk boundary enables larger-model SFT on consumer Apple Silicon. 5x memory savings on top of chunked VJP; 2x speedup when combined with Metal backward.
- gated_delta_vjp_compressed.py: pure-Python variant.
- gated_delta_vjp_metal_compressed.py: Metal + compression.
- gated_delta_rank_estimator.py: theorem-guided per-layer rank.
- compress.py: public estimate_rank*() API.
Activation via MLX_DELTANET_COMPRESS_* env vars.

- gated_delta_fused.py: compute_g + sigmoid + recurrence (3->1).
- gated_delta_t1.py: mega-fused T=1 (5->1), 1.24x speedup.
- gated_delta_factored_*.py: factored-state MSL kernel, 1.5x.
- gated_delta_inference_compressed.py: int4/int8 cache, 4x memory.
Auto-selected for T=1 inference in qwen3_5.py; opt-in via MLX_DELTANET_INFER_* env vars otherwise.

- gated_delta_prefix_scan.py: associative monoid formulation; associativity formally proven, numerically verified.
- tree_speculative.py: binary (K=2) tree speculative decoding; working prototype, performance tuning as follow-up.
- cache.py: KVCache.filter() and expand_batch() batch API needed for batched verifier.
Force-pushed from 5723422 to 350a9d6 (Compare)
Run pre-commit hooks (black 25.1.0, isort 6.0.0 --profile=black) on all added/modified Python files. No functional changes — only whitespace / import ordering / line-wrap style.
Author
Relationship to #496 and current
No code changes, comment/docstring rewording only.
Author
Closing to re-scope — will re-submit as a focused VJP-only PR when ready.
Add `gated_delta_update_vjp` — trainable GatedDeltaNet on Apple Silicon

Fixes Issue #482.
Why this is needed
`gated_delta_update` is the recurrent step used by every hybrid-attention Qwen3.5 / Qwen3-Next / Kimi-Linear model in `mlx-lm`. Its two existing implementations both fail for training:

- `gated_delta_kernel` — the fused Metal kernel used for inference — is registered as a `fast.metal_kernel`, which has no VJP: any model that calls it during `.train()` raises `ValueError: [Primitive::vjp] Not implemented for CustomKernel.`
- `gated_delta_ops` — the pure-Python fallback — unrolls the recurrence into an O(T)-node autodiff graph and runs out of memory at `T ≥ 2048` on 36 GB Apple Silicon devices, which covers the common LoRA fine-tuning setup.

So on an M-series Mac today, the linear-attention blocks of a Qwen3.5 model cannot be fine-tuned at all — the usual workaround is `mx.stop_gradient`, which freezes 75% of the layers (every `linear_attn` block) and limits fine-tuning to the minority of full-attention layers.
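For orientation, here is a pure-Python sketch of a gated delta-rule step of the general kind `gated_delta_update` implements. The exact gating and normalization in mlx-lm differ; everything in this block (names, the scalar gate, the update math) is illustrative, not the kernel's actual semantics:

```python
import numpy as np

def gated_delta_step(S, k, v, g, beta):
    """One illustrative gated delta-rule step (NOT the exact mlx-lm
    kernel math): decay the state by gate g, then apply a rank-1
    error-correcting (delta-rule) update with strength beta.
    S: (Dk, Dv) state, k: (Dk,) key, v: (Dv,) value."""
    S = g * S                        # gated decay of the state
    err = v - k @ S                  # prediction error for key k
    S = S + beta * np.outer(k, err)  # rank-1 delta correction
    return S

# Unrolling T such steps is exactly what builds the O(T) autodiff
# graph that OOMs in the pure-Python fallback.
rng = np.random.default_rng(0)
Dk, Dv, T = 4, 3, 8
S = np.zeros((Dk, Dv))
for _ in range(T):
    k = rng.standard_normal(Dk)
    k /= np.linalg.norm(k)
    S = gated_delta_step(S, k, rng.standard_normal(Dv), g=0.95, beta=0.5)
assert S.shape == (Dk, Dv)
```

With `g=0` and `beta=1` a single step writes `v` exactly at key `k`, which is the delta-rule fixed point the update is named after.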
What this PR adds
A drop-in `gated_delta_update_vjp(q, k, v, a, b, A_log, dt_bias, state, mask=None)` with identical shapes and semantics to `gated_delta_update`, but with a registered VJP so `mx.grad`, `mx.value_and_grad`, `mlx_lm.lora` and all other autodiff paths work. Two implementations ship:

- `gated_delta_vjp.py` — pure-Python reference using `mx.checkpoint` on a chunked recurrence. No MLX internals; works on any backend that supports `mx.checkpoint`. Slow but verifiable.
- `gated_delta_vjp_metal.py` — Metal backward kernel with `@mx.custom_function`. The forward path re-uses the upstream `gated_delta_kernel`; the backward runs a custom MSL kernel that does a reverse sweep over saved state history with threadgroup-local reduction (no atomic adds, fully deterministic).

Consumers in `qwen3_5.py`, `qwen3_next.py`, `kimi_linear.py` pick up the new function conditionally (`self.training and mask is None`) and fall back to the existing forward in all other paths, so inference, KV-cache decode, and evaluation are unchanged.
Numerical correctness
Numerical-gradient tests on a toy shape (`B=1, T=8, Hk=2, Hv=4, Dk=16, Dv=8`, fp32):

- Forward-only output matches `gated_delta_update(use_kernel=False)` bit-exactly on both fp32 and bf16.
- Metal-backward gradients match the Python reference to `< 2e-7` (floating-point noise).

Covered corner cases in the test suite:

- `T = 1`, `T = CHUNK_SIZE`, `T = CHUNK_SIZE + 1`
- prefix only (as in the `ops` path)
- `A_log = 5`, `a = 15`: no `NaN`/`Inf` thanks to the softplus-argument clamp.

Performance (Qwen3.5-9B GatedDeltaNet shape: `B=1, Hk=16, Hv=64, Dk=192, Dv=128`, bf16)

Compared with the Python VJP fallback, the Metal backward kernel is 8-11× faster and uses ~2× less memory. That is enough to make Qwen3.5-27B-dense LoRA training fit at `max_seq=2048`, and Qwen3.5-35B-A3B-MoE LoRA training fit at `max_seq=1024` on a 36 GB Mac — neither was possible without this PR.
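The finite-difference gradient checks from the Numerical correctness section follow the standard central-difference recipe. A minimal self-contained sketch (a toy scalar-output function stands in for the kernel; numpy is assumed available):

```python
import numpy as np

def f(x):
    # toy stand-in for a scalar loss through the recurrence
    return np.sum(np.tanh(x) * x)

def grad_f(x):
    # hand-written analytic gradient of f (the role the custom VJP plays)
    return np.tanh(x) + x * (1.0 - np.tanh(x) ** 2)

def fd_grad(f, x, eps=1e-5):
    # central differences, one coordinate at a time: O(eps^2) accurate
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e.flat[i] = eps
        g.flat[i] = (f(x + e) - f(x - e)) / (2 * eps)
    return g

x = np.linspace(-1.0, 1.0, 8)
assert np.max(np.abs(grad_f(x) - fd_grad(f, x))) < 1e-6
```

The real tests do the same thing per primal argument (`q`, `k`, `v`, …), comparing the registered VJP's output against the central-difference estimate.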
End-to-end training (Qwen3.5-9B LoRA, 50 iter, `max_seq=2048`)

Metal VJP matches the Python VJP loss curve iteration-by-iteration (same seed, identical `Val` losses at every eval step) while being comparable in speed to the frozen-baseline workaround.

Running on `max_seq=4096` with the full unfiltered train split — which was previously impossible — converges identically (`Val 0.433 → 0.200` in 50 iter, peak 12.2 GB).
Design notes
Why a custom `.vjp` instead of `mx.checkpoint`-wrapping the python-ops path? `mx.checkpoint` alone still produces an O(T) autodiff graph inside the recurrence; it saves no tensors across forward/backward but has to expand the whole loop as MLX primitives, which is what OOMs at `T ≥ 2048`. The custom kernel keeps the graph at O(T / CHUNK_SIZE) chunks — one primitive per 64 timesteps.

Why chunked instead of single-shot? One `state_history` buffer for a full `T=2048` sequence is ~400 MB at Qwen3.5-9B shapes (bf16). The default `CHUNK_SIZE = 64` bounds the peak to a single chunk at a time while keeping per-chunk Metal launch overhead negligible.
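The chunking strategy can be illustrated in plain Python: save only chunk-boundary states in the forward pass, then recompute each chunk's interior during the backward sweep. A toy scalar linear recurrence stands in for the real kernel here; all names and the step function are illustrative assumptions:

```python
CHUNK = 4

def step(s, u, g=0.9):
    # toy linear recurrence standing in for the gated delta update
    return g * s + u

def forward_chunked(s0, inputs):
    """Run the recurrence, saving only states at chunk boundaries:
    O(T / CHUNK) saved tensors instead of O(T)."""
    boundaries, s = [s0], s0
    for t, u in enumerate(inputs, 1):
        s = step(s, u)
        if t % CHUNK == 0:
            boundaries.append(s)
    return s, boundaries

def backward_chunked(boundaries, inputs, g=0.9):
    """Reverse sweep: recompute each chunk from its saved boundary,
    then backprop through it (d step/d u = 1, d step/d s = g)."""
    T = len(inputs)
    grads = [0.0] * T
    cot = 1.0  # cotangent: d(final state)/d(final state)
    for c in reversed(range(T // CHUNK)):
        # recompute this chunk forward from its boundary state;
        # unused for this linear toy, but needed by the real kernel
        s = boundaries[c]
        for t in range(c * CHUNK, (c + 1) * CHUNK):
            s = step(s, inputs[t])
        for t in reversed(range(c * CHUNK, (c + 1) * CHUNK)):
            grads[t] = cot
            cot *= g
    return grads

inputs = [0.1 * i for i in range(8)]
final, b = forward_chunked(0.0, inputs)
grads = backward_chunked(b, inputs)
# analytic check: d final / d u_t = g ** (T - 1 - t)
assert all(abs(grads[t] - 0.9 ** (len(inputs) - 1 - t)) < 1e-12
           for t in range(len(inputs)))
```

Peak memory is one chunk of recomputed states plus the O(T / CHUNK) boundary list, which is the trade the design note describes.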
Why threadgroup-local reduction, not atomics? `atomic_fetch_add` on Apple GPUs is relaxed-ordering, which produces non-deterministic rounding over large reductions. In training the accumulated noise degrades convergence measurably (a first prototype that used atomics hit a val loss of 0.349 where the deterministic version reaches 0.142 — i.e. back to the frozen-layer baseline). The current kernel reduces across the four SIMD groups of each threadgroup via shared memory and defers the final Python sum to a deterministic axis-reduction.

Models covered
Inference, KV-cache decode, and evaluation on every path fall through to the existing `gated_delta_update`; there are no inference changes.

Files
- `mlx_lm/models/gated_delta_vjp.py` — pure-Python reference VJP with scalar and vectorised gating.
- `mlx_lm/models/gated_delta_vjp_metal.py` — Metal backward + fast forward (via the upstream `gated_delta_kernel`).
- `mlx_lm/models/qwen3_5.py`, `qwen3_next.py`, `kimi_linear.py` — pick the VJP in training, fall back to the existing forward otherwise.
- `tests/test_gated_delta_vjp.py` — 12 pytest cases covering: FD gradient check (all primal args), forward equivalence vs reference (bit-exact), edge lengths (1, chunk, chunk+1), numerical clamp under extreme inputs, mask handling (all-True == unmasked, half-masked carry-over, FD through masked path), Metal equivalence (forward + per-arg gradients, skipped if Metal unavailable).
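As background for the deterministic-reduction design note above: floating-point addition is not associative, so a reduction whose accumulation order varies run-to-run (as with relaxed atomics) can change the result bits, while a fixed-order tree reduction is reproducible by construction. A tiny pure-Python illustration:

```python
# floating-point addition is not associative: the 0.1 survives or is
# absorbed depending on which pair is added first
a, b, c = 1e16, -1e16, 0.1
assert (a + b) + c == 0.1
assert a + (b + c) == 0.0   # 0.1 is absorbed into -1e16 before cancelling

def tree_sum(xs):
    """Fixed-order pairwise reduction — deterministic by construction,
    loosely analogous to a threadgroup's SIMD-group reduction: the same
    inputs always combine in the same order."""
    xs = list(xs)
    while len(xs) > 1:
        xs = [xs[i] + xs[i + 1] if i + 1 < len(xs) else xs[i]
              for i in range(0, len(xs), 2)]
    return xs[0]

vals = [((-1) ** i) * (1.0 + i) * 1e-8 for i in range(1000)]
# every call combines in the same order -> bit-identical results
assert tree_sum(vals) == tree_sum(vals)
```

This is only a scalar analogy for why the kernel avoids `atomic_fetch_add`; the actual MSL reduction shape is described in the design notes.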
Scope and follow-ups
Supported today:

- scalar gating (`g.ndim == 3`) on both Python and Metal paths.
- vectorised gating (`g.ndim == 4`) on the Python path (covers Kimi-Linear out of the box).
- …that `fast.metal_kernel` accepts.

Not in this PR — tracked as follow-ups:

- mask-aware Metal backward kernel.
- chunk-parallel variant, validated on the batched Python reference (2.6× forward speedup on `T=64` in pure Python, up to 6× on smaller chunks); the MSL implementation plus the gradient derivation through the lower-triangular solve is a research-grade follow-up on its own. See `CHUNK_PARALLEL_NOTES.md` and `gated_delta_chunk_parallel_batched.py` for the full derivation and PoC.
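For context on the lower-triangular solve mentioned in the follow-ups: chunk-parallel delta-rule formulations typically reduce each chunk to a unit-lower-triangular system, which forward substitution solves sequentially. A hedged numpy sketch (shapes and names are illustrative, not the PR's actual code):

```python
import numpy as np

def forward_substitution(L, b):
    """Solve L x = b for unit lower-triangular L (ones on the
    diagonal), row by row: x[i] = b[i] - L[i,:i] @ x[:i]."""
    n = b.shape[0]
    x = np.zeros_like(b)
    for i in range(n):
        x[i] = b[i] - L[i, :i] @ x[:i]
    return x

rng = np.random.default_rng(0)
n = 16  # illustrative chunk length
L = np.tril(rng.standard_normal((n, n)), k=-1) + np.eye(n)
b = rng.standard_normal(n)
x = forward_substitution(L, b)
assert np.allclose(L @ x, b)
```

Differentiating through this solve (the `d x / d L` and `d x / d b` terms) is the part the PR defers as research-grade follow-up work.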
Reviewer notes
The Metal kernels are ~400 LoC of MSL with four small templated parameters (`Dk`, `Dv`, `Hk`, `Hv`) and are loaded via the existing `mx.fast.metal_kernel`. They reuse the same thread-grid layout as the upstream forward kernel, so the diff is easy to follow.

All tests run locally on an M4 Max with MLX 0.31.1. Please point out any naming / layout conventions that would fit better with the rest of the `mlx-lm` codebase — happy to adjust.

Post-initial-submission extensions
After the initial VJP work, several additional contributions landed in
the same branch. They sit alongside the core VJP change and can be
split into separate PRs if maintainers prefer.
Extension 1 — Compression-aware training
- `gated_delta_vjp_compressed.py`: power-iteration rank-r truncation at every chunk boundary. 5× additional memory savings on top of chunked VJP. Enables SFT of 35B+ models on 36 GB Apple Silicon.
- `gated_delta_vjp_metal_compressed.py`: Metal VJP backward combined with rank truncation between chunks. 2× speedup over pure-Python compressed (~10 s/iter vs 31 s/iter on Qwen3.5-9B), 30% less peak memory (9.2 GB vs 12.8 GB). Same loss curve (correctness preserved). Default backend when compression env vars are set.
Full 500-iter training validation on Qwen3.5-9B, T=2048, per-layer theorem-guided ranks (avg rank ~7.6, `MLX_DELTANET_COMPRESS_ITERS=3`):
Peak memory stayed constant at 9.22 GB throughout the 500 iterations
— no leaks, no drift. Train loss steadily decreased from 0.32 → 0.06.
Val loss oscillates in [0.11, 0.29] due to the small `val_batches=2` (high variance), but the running mean stays around 0.17. Best checkpoint is iter 100 (val 0.109); the final iter-500 checkpoint remains production-quality.
Total time: 83 minutes on M4 Max (~10 sec/iter).
- `gated_delta_rank_estimator.py` + `mlx_lm.compress` public module: theorem-guided rank selection. Users call `estimate_rank_per_layer()` once, write JSON, and point `MLX_DELTANET_COMPRESS_RANK_PER_LAYER` at it for 56% less training memory vs uniform rank=16.
Env vars:
Extension 2 — Inference kernel fusion
Three new MSL kernels, each a drop-in replacement for the existing forward path for T=1 inference:

- `gated_delta_fused.py`: compute_g + sigmoid + recurrence (3→1). 1.13× kernel-level speedup, verified via FD.
- `gated_delta_t1.py`: mega-fused T=1 — rms_norm(q/k) + inv_scale + compute_g + sigmoid + recurrence (5→1). 1.24× kernel-level, correctness at bf16 machine precision.
- `gated_delta_factored_fixed.py`: fixed-rank factored-state kernel with round-robin slot replacement. 1.5× kernel-only, infrastructure for future fused multi-layer work.

Auto-activated for T=1 inference in `qwen3_5.py` (no env var required).

Extension 3 — Inference memory
- `gated_delta_inference_compressed.py`: int8/int4 DeltaNet cache quantization via `mx.quantize`. 4× cache memory reduction, 3% compute overhead. Enables higher concurrency for Apple Silicon batch serving.
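A hedged numpy sketch of the cache-quantization idea: symmetric per-row int8, a simplified stand-in for `mx.quantize`'s grouped scheme. All names and shapes here are illustrative, not the PR's code:

```python
import numpy as np

def quantize_int8(x):
    """Symmetric per-row int8 quantization: one fp scale per row,
    values rounded to [-127, 127]."""
    scale = np.abs(x).max(axis=-1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)  # avoid divide-by-zero rows
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize_int8(q, scale):
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
state = rng.standard_normal((4, 256)).astype(np.float32)
q, s = quantize_int8(state)
rec = dequantize_int8(q, s)
# worst-case rounding error is half a quantization step per entry
assert np.abs(state - rec).max() <= 0.5 * s.max() + 1e-6
assert q.nbytes == state.nbytes // 4   # 4x memory reduction vs fp32
```

int4 halves the storage again at the cost of coarser steps, which is where the quoted ~3% compute overhead (dequantize on read) trades against the 4× memory reduction.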
Extension 4 — Tree speculative decoding (prototype)
- `tree_speculative.py` + `KVCache.filter()` / `.expand_batch()` cache API extensions. Binary-tree (K=2) working prototype — correctness PASS, performance tuning deferred. Draft cache rewind logic is the remaining bottleneck (linear speculation is still faster on Qwen3.5 at accept rate > 0.5). Framework infrastructure ready for future research.
Extension 5 — Associative prefix-scan scaffold
- `gated_delta_prefix_scan.py`: formally proven associative monoid formulation of the recurrence. Equivalence vs the sequential reference verified at machine precision. Enables future distributed / O(log T) scan implementations (currently falls back to sequential).
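The monoid idea can be checked in a few lines: for a recurrence of the shape `S_t = g_t * S_{t-1} + u_t`, pairs `(a, b)` representing affine maps `S -> a*S + b` compose associatively, which is what makes a parallel prefix scan valid. Scalars stand in for the matrix-valued state below; this is an illustration, not the PR's formulation:

```python
from functools import reduce

def combine(x, y):
    """Compose two affine steps: apply x first, then y.
    y(x(S)) = a2*(a1*S + b1) + b2 = (a2*a1)*S + (a2*b1 + b2)."""
    a1, b1 = x
    a2, b2 = y
    return (a2 * a1, a2 * b1 + b2)

steps = [(0.9, 0.1), (0.5, -0.2), (1.1, 0.3), (0.7, 0.05)]

# associativity (up to float rounding): combine((x*y), z) == combine(x, (y*z))
x, y, z = steps[:3]
lhs = combine(combine(x, y), z)
rhs = combine(x, combine(y, z))
assert all(abs(l - r) < 1e-12 for l, r in zip(lhs, rhs))

# folding the monoid reproduces the sequential recurrence exactly
s0, seq = 1.0, 1.0
for a, b in steps:
    seq = a * seq + b
a_tot, b_tot = reduce(combine, steps)
assert abs((a_tot * s0 + b_tot) - seq) < 1e-12
```

Associativity is what allows regrouping the fold into a balanced tree, giving the O(log T) depth that a future parallel scan implementation would exploit.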
Extension 6 — Empirical grounding
The 1.24× fusion and the ≤16 compression ranks are grounded in a proved O(1) stable-rank result for trained linear-attention state. Replicated on 8 models × 3 architecture families (GatedDeltaNet, Mamba-2, RWKV-7). Reproduction scripts are attached in the research companion.
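One standard notion of stable rank is ||A||_F² / ||A||_2², a smooth quantity bounded above by the true rank; assuming that is the notion the result refers to, a quick numpy check (illustrative only, not the PR's estimator):

```python
import numpy as np

def stable_rank(A):
    """Stable rank ||A||_F^2 / ||A||_2^2: a smooth proxy for rank,
    always in [1, rank(A)]."""
    fro2 = np.sum(A * A)
    spec = np.linalg.norm(A, ord=2)  # largest singular value
    return fro2 / (spec * spec)

rng = np.random.default_rng(0)
# an exactly rank-3 matrix: stable rank must lie in [1, 3]
A = rng.standard_normal((128, 3)) @ rng.standard_normal((3, 96))
r = stable_rank(A)
assert 1.0 <= r <= 3.0
```

An O(1) stable rank for the trained state is what would justify truncating to a small fixed rank (the ≤16 used above) with bounded error, independent of the state's full dimensions.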
Benchmarks summary table