feat: add DFlash speculative decoding #78
Conversation
Critical bug fix and performance optimizations for DFlash speculative decoding. Acceptance rate improved from 25% to 89% (matching the Python reference); throughput from 6.7 to 42 tok/s.

Root cause: hiddenNorm was declared without @ModuleInfo, so its RMSNorm weight was never loaded from safetensors. The key "hidden_norm.weight" didn't match the reflected key "hiddenNorm.weight", leaving the weight at all-ones instead of the trained values (~0.98). This single missing weight distorted every draft prediction, compounding through all 5 draft layers.

Fix: added the @ModuleInfo(key: "hidden_norm") annotation, matching the safetensors key. Also added @ModuleInfo for norm and fc for consistency.

Performance optimizations:
- Streaming: replaced generateSync + buffered array with generateStreaming + Continuation, yielding tokens immediately
- Draft prefetch: launch the next cycle's draft with asyncEval before rollback, overlapping GPU work
- Batched asyncEval: changed blocking eval() to asyncEval() for verify logits and hidden states
- asyncEval(committedHidden): unblocks the prefetch window
- Stop token Set: precomputed O(1) lookup
- Removed double fflush, added DFlashDumper call-site guards

Submodule updates:
- mlx-swift-lm: exactSmallProjPad for quantized linear at small seq_len (<16), DFlash protocols, open MambaCache/ArraysCache
- mlx-swift: remove stale .air kernel files

Benchmark (Qwen3.5-27B-4bit, thinking mode, 2048 tokens): 41.9 tok/s, 89.4% acceptance, 216 cycles
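For readers unfamiliar with the key-mapping mechanism, here is a minimal sketch of the fix using MLXNN's `@ModuleInfo` property wrapper. The dimensions and layer shape are illustrative, not the draft model's real config:

```swift
import MLX
import MLXNN

// Without an explicit key, MLXNN reflects the Swift property name, so
// "hiddenNorm.weight" never matches the checkpoint key "hidden_norm.weight"
// and the parameter silently keeps its all-ones init value.
final class DraftLayerSketch: Module {
    // Before: `var hiddenNorm = RMSNorm(dimensions: 2048)` — weight never loaded.
    // After: the explicit key matches the safetensors layout.
    @ModuleInfo(key: "hidden_norm") var hiddenNorm: RMSNorm
    @ModuleInfo(key: "norm") var norm: RMSNorm
    @ModuleInfo(key: "fc") var fc: Linear

    init(dimensions: Int = 2048) {
        self._hiddenNorm.wrappedValue = RMSNorm(dimensions: dimensions)
        self._norm.wrappedValue = RMSNorm(dimensions: dimensions)
        self._fc.wrappedValue = Linear(dimensions, dimensions)
        super.init()
    }
}
```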
… streaming

When SSD expert streaming is active, expert weight tensors (.weight) are replaced with zero-filled placeholders of the correct shape/dtype during loading. Only scales and biases are loaded into RAM; the actual expert weight data is read from SSD at runtime via pread/mmap.

RAM savings for MoE models:
- Qwen3.6-35B-A3B: 18.4 GB → 5.1 GB (73% reduction)
- Expert weights skipped: 16.1 GB (weight only, not scales/biases)
- Expert scales+biases loaded: ~2 GB (needed for dequantization)

Performance on Qwen3.6-35B-A3B (512 tokens, math prompt):
- No SSD streaming: 11.5 tok/s, 18.4 GB RAM
- SSD streaming only: 11.5 tok/s, 5.1 GB RAM
- SSD + DFlash: 32.2 tok/s, 5.1 GB RAM
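A minimal sketch of the placeholder idea (not the PR's actual loader), assuming MLX's `zeros(_:dtype:)` factory; `isStreamedExpertWeight` is a hypothetical predicate introduced here for illustration:

```swift
import MLX

// Hypothetical predicate: expert weight tensors, but not their scales/biases.
func isStreamedExpertWeight(_ key: String) -> Bool {
    key.contains("experts") && key.hasSuffix(".weight")
}

// Expert `.weight` tensors keep only shape/dtype so the module tree validates;
// the real data stays on SSD and is fetched at runtime via pread/mmap.
func substitutePlaceholders(_ weights: [String: MLXArray]) -> [String: MLXArray] {
    var resident: [String: MLXArray] = [:]
    for (key, value) in weights {
        if isStreamedExpertWeight(key) {
            // Zero-filled placeholder with the correct shape and dtype.
            resident[key] = zeros(value.shape, dtype: value.dtype)
        } else {
            // Scales and biases stay in RAM; they are needed for dequantization.
            resident[key] = value
        }
    }
    return resident
}
```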
Both streaming and non-streaming chat/text completion responses now include a 'timings' object with:
- predicted_per_second: generation speed in tokens/second
- predicted_n: number of completion tokens
- predicted_ms: total generation wall-clock time in ms

This matches llama-server's timing convention and allows clients to read generation speed directly from the API response without external measurement.
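A sketch of the response fields as a Swift `Codable` type, assuming the llama-server-style snake_case keys listed above (the actual Server.swift types may differ):

```swift
struct Timings: Codable {
    let predictedPerSecond: Double  // generation speed, tokens/second
    let predictedN: Int             // number of completion tokens
    let predictedMs: Double         // total generation wall-clock time, ms

    enum CodingKeys: String, CodingKey {
        case predictedPerSecond = "predicted_per_second"
        case predictedN = "predicted_n"
        case predictedMs = "predicted_ms"
    }
}
```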
Tests 4 configurations for Qwen3.6-35B-A3B-4bit with the same math prompt:
- Baseline (no SSD, no DFlash)
- SSD Streaming only
- SSD Streaming + DFlash
- DFlash only

Results (512 tokens, 3 runs each):
- Baseline: 26.3 tok/s, 18.8 GB RAM
- SSD Streaming: 12.5 tok/s, 5.4 GB RAM
- SSD + DFlash: 33.3 tok/s, 7.4 GB RAM ← best tradeoff
- DFlash only: 125.4 tok/s, 20.0 GB RAM
- Add StreamableMoE conformance to Qwen3NextModelInner - Add LayerPartitionable conformance to Qwen3NextModelInner - Add DFlashTargetModel conformance to Qwen3NextModel - dflashEmbedTokens, dflashLmHeadLogits, dflashForwardWithCapture - dflashGatedDeltaForward with tape recording for GDN rollback - Add dflashForwardWithTape to Qwen3NextGatedDeltaNet - Add bridge file Qwen3Next+DFlash.swift - Short prompt works: 68.8% acceptance, 9.8 GB RAM (vs 45 GB full load) - Longer runs crash — likely Metal watchdog on 512-expert SSD reads
…el cache

Replace if-branch masking with metal::select for zero warp-divergence state updates. Reorganize KernelCache from 8 flat named vars to tapeReplay[vec][msk] and gatedDeltaTape[vec][msk] 2D arrays. Simplify dispatch call sites to one-liner index lookups. Minor whitespace cleanup in DFlashIntermediateDumper.
… property

Add MambaSnapshotCache: lightweight O(1) snapshot-based rollback (lazy reference capture, no GPU copy) as an alternative to RecurrentRollbackCache's innovation-tape replay. Add a dflashUseTapeRollback Bool to DFlashTargetModel (default true) so models can opt in to either strategy. Update makeTargetCache and arm/rollback helpers with clearer comments.

Also switch RecurrentRollbackCache.armRollback to lazy reference capture (removes unnecessary MLX.contiguous copies on the arm path).
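A minimal sketch of the snapshot idea, assuming the recurrent state is held as `MLXArray` handles. Because updates produce new arrays rather than mutating buffers in place, capturing the current references is O(1) and copies no GPU memory:

```swift
import MLX

final class SnapshotRollbackSketch {
    var state: [MLXArray] = []          // e.g. conv state + SSM state per layer
    private var snapshot: [MLXArray]?

    // Arm before speculative decoding: lazy reference capture, no MLX.contiguous.
    func armRollback() { snapshot = state }

    // Discard speculative updates by restoring the captured references.
    func rollback() {
        if let snapshot { state = snapshot }
        snapshot = nil
    }

    // Accept the speculative tokens: drop the snapshot, keep the new state.
    func commit() { snapshot = nil }
}
```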
Add DFlashKernelBench executable for isolated kernel timing. Exclude DFlashKernelsOptimized.swift from the DFlash library target (work-in-progress alternative kernel implementations kept for reference).
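Excluding a file from a target uses standard SwiftPM API; a trimmed Package.swift sketch (target names follow the PR description, the rest of the manifest is elided):

```swift
// swift-tools-version:5.9
import PackageDescription

let package = Package(
    name: "SwiftLM",
    targets: [
        // The WIP optimized kernels stay in the source directory for reference
        // but are excluded from compilation.
        .target(name: "DFlash", exclude: ["DFlashKernelsOptimized.swift"]),
        .executableTarget(name: "DFlashKernelBench", dependencies: ["DFlash"]),
    ]
)
```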
…next.sh

bench_35b.sh: save per-run raw response JSON, extract structured results into bench_results.json (tok/s, RAM, timing per config) for downstream tooling, and use the slug variable consistently for log file naming. Add bench_coder_next.sh for benchmarking Qwen3-Coder-Next model variants.
Move comparison tests from tests/DFlashComparison/ to tests/DFlash/, adding DFlashBenchmark.swift, DFlashProfiler.swift, updated cosine similarity comparison tools, and a README. Update .gitignore intermediates path.
…tension

DFlash protocol methods (dflashEmbedTokens, dflashLmHeadLogits, dflashForwardWithCapture, dflashIsHybridGDN) moved from Qwen3Next.swift into Sources/SwiftLM/Qwen3Next+DFlash.swift, matching the pattern used by Qwen35+DFlash.swift. Requires mlx-swift-lm commit a707519 (3 public access-modifier additions).
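A self-contained sketch of the bridge-file pattern. The protocol and model here are stand-ins (`…Sketch` types); the real protocol lives in the DFlash module and the real model in mlx-swift-lm, where the signatures may differ:

```swift
import MLX

protocol DFlashTargetModelSketch {
    func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray
    func dflashLmHeadLogits(_ hidden: MLXArray) -> MLXArray
}

final class Qwen3NextModelSketch {
    // Stand-ins for the model's now-public embed/lm_head entry points.
    func embedTokens(_ tokens: MLXArray) -> MLXArray { tokens }
    func lmHead(_ hidden: MLXArray) -> MLXArray { hidden }
}

// The conformance lives in its own file (Qwen3Next+DFlash.swift) and simply
// delegates, so nothing inside the submodule has to change beyond access level.
extension Qwen3NextModelSketch: DFlashTargetModelSketch {
    func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray { embedTokens(tokens) }
    func dflashLmHeadLogits(_ hidden: MLXArray) -> MLXArray { lmHead(hidden) }
}
```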
Pull request overview
Adds a new DFlash speculative decoding runtime and integrates it into the SwiftLM server, along with kernel benchmarking and cross-language intermediate-dump tooling to validate numerical parity.
Changes:
- Introduces `Sources/DFlash/` runtime (engines, draft model/backend, rollback caches, kernels, registry, dumper) and a `DFlash` library product.
- Integrates DFlash into `SwiftLM` via `--dflash` and draft auto-resolution/loading; adds timing fields to responses and new benchmark scripts.
- Adds a kernel micro-benchmark target (`DFlashKernelBench`) and a reorganized DFlash test/tooling suite under `tests/DFlash/`.
Reviewed changes
Copilot reviewed 25 out of 26 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| tests/DFlash/dump_python_intermediates.py | Python-side reference dump of intermediates for Swift↔Python comparison. |
| tests/DFlash/compare_swift_python.py | Compares Swift dumps vs Python dumps using cosine similarity. |
| tests/DFlash/compare_cosine.py | Self-consistency and “Swift-equivalent” Python path comparison tooling. |
| tests/DFlash/README.md | Documentation for DFlash benchmarking/comparison tools. |
| tests/DFlash/DFlashProfiler.swift | Swift profiler for kernel performance and basic correctness checks. |
| tests/DFlash/DFlashCosSimComparison.swift | Swift tool to load .npy and compare/inspect intermediates (partial WIP). |
| tests/DFlash/DFlashBenchmark.swift | End-to-end benchmark harness for baseline vs DFlash performance (tooling). |
| bench_coder_next.sh | Benchmark script for Qwen3-Coder-Next across baseline/SSD/DFlash configs. |
| bench_35b.sh | Benchmark script for 35B model; adds rich JSON export for downstream tooling. |
| Sources/SwiftLM/Server.swift | Adds --dflash path, draft auto-resolution/loading, and timing fields in responses. |
| Sources/SwiftLM/Qwen3Next+DFlash.swift | Adds DFlashTargetModel conformance for Qwen3NextModel. |
| Sources/SwiftLM/Qwen35+DFlash.swift | Adds DFlashTargetModel conformance for Qwen35 models. |
| Sources/DFlashKernelBench/main.swift | New micro-benchmark executable for DFlash Metal kernels (trace-friendly). |
| Sources/DFlash/RecurrentRollbackCache.swift | Adds recurrent tape rollback cache + snapshot rollback alternative. |
| Sources/DFlash/DFlashRuntime.swift | Core DFlash generation loop + cache management + token utilities. |
| Sources/DFlash/DFlashKernelsOptimized.swift | Alternative optimized kernels implementation (currently excluded from build). |
| Sources/DFlash/DFlashKernels.swift | Main kernel implementations (tape replay, gated-delta+tape, SDPA 2-pass). |
| Sources/DFlash/DFlashIntermediateDumper.swift | Writes Swift intermediates to .npy for Python tooling. |
| Sources/DFlash/DFlashEngine.swift | Defines FullAttentionEngine and HybridGDNEngine rollback behavior. |
| Sources/DFlash/DFlashDraftRegistry.swift | Maps target model refs to draft model refs for auto-resolution. |
| Sources/DFlash/DFlashDraftModel.swift | Implements the DFlash draft model architecture and context feature extraction. |
| Sources/DFlash/DFlashDraftBackend.swift | Implements greedy drafting logic using target embed/lm_head and draft model. |
| Package.swift | Adds DFlash library product and DFlashKernelBench executable target. |
| .gitignore | Ignores generated intermediate dump directory. |
Comments suppressed due to low confidence (1)
Sources/SwiftLM/Server.swift:1013
- This `extension ModelContainer` block is empty (only doc comments), so it adds no functionality and can confuse readers about missing API. Either implement the intended helper (e.g. `extractDFlashTargetModel()`) or remove the empty extension.

    }
    // Use the most recently modified snapshot
    let sorted = snapshots
        .filter { (try? $0.resourceValues(forKeys: [.isDirectoryKey]).isDirectory) == true }
- DFlashBenchmark: fix operator-precedence bug in memoryGB calculation
- DFlashBenchmark: replace unsafe `as! DFlashTargetModel` with guarded cast + exit
- DFlashBenchmark: replace NSNumber-casting median with BinaryFloatingPoint/BinaryInteger overloads
- DFlashRuntime: wrap generateStreaming in a Task inside AsyncStream to avoid blocking the caller (see the sketch after this list)
- DFlashRuntime: fix first-token duplication by skipping the append (not just the yield) for the already-emitted token
- DFlashRuntime: replace O(vocabSize*n) suppress-mask broadcast with an O(vocabSize) scatter
- DFlashIntermediateDumper: fix .npy header with spec-compliant shape tuples and newline as the final byte
- Server: remove dead speculative-decoding branch that logged but passed no draft model
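A minimal sketch of the AsyncStream fix: the producer runs in a `Task`, so building the stream returns immediately instead of blocking the caller until generation completes. `generate` is a stand-in for the real generateStreaming call:

```swift
func tokenStream(
    generate: @escaping (_ yield: @escaping (Int) -> Void) -> Void
) -> AsyncStream<Int> {
    AsyncStream { continuation in
        // Produce tokens off the caller's path; the stream is returned at once.
        let task = Task {
            generate { token in continuation.yield(token) }
            continuation.finish()
        }
        // If the consumer stops iterating, cancel the producer.
        continuation.onTermination = { _ in task.cancel() }
    }
}
```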
@solderzzc thanks for the review, will post benchmarks from M3 Ultra soon
# Conflicts:
#	README.md
@0xClandestine Thanks for your PR. I'm working on setting up GitHub Actions for test automation, and collecting benchmarks on my M5 Pro 64GB.
Prompt cache save/restore was incorrectly applied to Qwen3Next, which uses a hybrid KVCache+MambaCache architecture. Unlike KVCacheSimple, MambaCache RNN states cannot be trimmed or replayed at arbitrary token boundaries, so attempting to restore a partial match would corrupt the linear attention state and cause spurious 1-token outputs.

Fix: PromptCache.save() and PromptCache.restore() now skip immediately if any layer in the cache is a MambaCache instance.

Also fixes run_benchmark.sh Test 0 (automated matrix) to pass MODEL via an environment variable instead of feeding it through stdin, so the model selection prompt is correctly bypassed when MODEL is pre-set.
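The guard itself is a one-liner; a sketch with illustrative stand-ins for the cache layer types from mlx-swift-lm (the real check lives in PromptCache.save()/restore()):

```swift
// Illustrative stand-ins; the real classes come from mlx-swift-lm.
class KVCacheSimple {}
class MambaCache {}

func savePromptCache(_ layers: [AnyObject]) {
    // Recurrent MambaCache state cannot be trimmed or replayed at arbitrary
    // token boundaries, so skip persistence for hybrid models entirely.
    guard !layers.contains(where: { $0 is MambaCache }) else { return }
    // ... persist KVCacheSimple layers as before ...
}
```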
Replacing the stdin pipe approach with an env var so child invocations from Test 0's automated matrix loop skip the interactive menu entirely. The previous echo-pipe was consumed by the 'read suite_opt' prompt but any subsequent reads (model selection) had no input, causing the script to fall through to option 3 by default.
When SUITE_OPT is set (automated matrix mode), skip all menu echoes and the read prompt entirely. Child processes now run silently with only test-relevant output.
Both test-speculative.sh and test-dflash.sh grep for 'Using speculative decoding' in the server log to confirm the speculative path was activated. This string was never emitted; the tests were checking a log line that didn't exist, causing the speculative-decoding and dflash-speculative-decoding CI jobs to always fail on Test 1.

Fix: emit the exact expected log line:
- Standard spec: after the draft model is loaded successfully
- DFlash spec: at generation dispatch in Server.swift

The server log now contains all strings the tests grep for:
✅ 'Draft model loaded successfully'
✅ 'Using speculative decoding'
✅ 'speculative decoding' (for test-speculative-eval.sh)
test-dflash.sh grepped for:
1. 'Draft model loaded successfully' — only emitted by standard draft path,
not DFlash path which has its own 'DFlash draft model loaded' message
2. 'Using speculative decoding' — not emitted by DFlash path at all
3. 'speculative decoding' — was present but test was failing on (1)
Add both required lines immediately after DFlash draft model weights load,
mirroring the standard speculative decoding path. The streaming failures
('missing [DONE] sentinel') were downstream of the model-not-found state
caused by the load log mismatch, not an inference bug.
d9f824b to 4c042a6
Adds Sources/SwiftLM/{Qwen3,Qwen3MoE,Llama}+DFlash.swift — each
declares the DFlashTargetModel protocol conformance and delegates to
the model's public callCapturing / embedTokens / lmHead
(now on *ModelInner via mlx-swift-lm b453).
Coverage:
Qwen3Model → Qwen3-8B and similar dense Qwen3 variants
Qwen3MoEModel → Qwen3-Coder-30B-A3B and other Qwen3 MoE variants
LlamaModel → Meta-Llama-3.x, Mistral, and Llama-family models
Qwen35MoEModel → already covered via Qwen35Model inheritance
Qwen36MoE → no separate Swift class found; uses Qwen35MoE path
Co-authored-by: clandestine.eth <96172957+0xClandestine@users.noreply.github.com>
Gemma4 omni (5.2GB) on a 7.5GB runner is tight. After other CI jobs have run and filled the model cache, available RAM can drop below the threshold needed for stable Metal command buffer execution, causing sporadic GPU timeout crashes (kIOGPUCommandBufferCallbackErrorTimeout). Add a vm_stat-based preflight check: if available+inactive RAM < 2.5GB, exit 0 (skip) instead of crashing the whole run.
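The CI check itself is shell-based; for illustration, a rough Swift equivalent of the same preflight. The `vm_stat` field names and the 16 KB page size are macOS conventions, and the ~2.5 GB threshold comes from the commit:

```swift
import Foundation

// Parse `vm_stat`, estimate free + inactive RAM, and skip (exit 0) when low.
// vm_stat prints counts like "Pages free:    12345."; page size is 16 KB on
// Apple silicon.
func availableRAMBytes() -> UInt64? {
    let vmStat = Process()
    vmStat.executableURL = URL(fileURLWithPath: "/usr/bin/vm_stat")
    let pipe = Pipe()
    vmStat.standardOutput = pipe
    do { try vmStat.run() } catch { return nil }
    let data = pipe.fileHandleForReading.readDataToEndOfFile()
    vmStat.waitUntilExit()
    guard let text = String(data: data, encoding: .utf8) else { return nil }
    var pages: UInt64 = 0
    for line in text.split(separator: "\n")
    where line.hasPrefix("Pages free") || line.hasPrefix("Pages inactive") {
        pages += UInt64(line.filter(\.isNumber)) ?? 0
    }
    return pages * 16_384
}

if let available = availableRAMBytes(), available < UInt64(2.5 * 1_073_741_824.0) {
    print("Skipping: insufficient free RAM for stable Metal execution")
    exit(0)
}
```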
This reverts commit 9fc993c.
Own DeepSeek V3 (deepseek_v3 / kimi_k25) and Kimi Linear (kimi_linear) model implementations directly in SwiftLM so DFlashTargetModel conformance is available without any upstream submodule changes.

- DeepseekV3DFlash.swift: full DSV3Config + model with callCapturing
- KimiLinearDFlash.swift: hybrid KDA/MLA Kimi 2.6 model with DFlash
- DFlashModelRegistry.swift: registers all three model types via LLMTypeRegistry.shared.registerModelType() at startup
- Server.swift: call registerDFlashModelTypes() before model loading
Use @ModuleInfo(key: "model") on the inner model property so weights at model.* paths are found correctly. Also use @ModuleInfo(key: "norm") for norm layers initialized in init() so their weights are tracked.
🐛 Bugs & Required Changes

| File | Change | Severity |
|---|---|---|
| DeepseekV3.swift | Strip `language_model.` prefix in `sanitize()` | Blocker |
| Module.swift | Handle `.none` entries in `update(modules:)` | Blocker |
| Load.swift | `verify: []` for mixed-precision models | Blocker |
| swift-transformers | Tiktoken support / proper BPE decoder | Blocker for tiktoken models |
| Server launch | FD limit `ulimit -n 1024+` | Runtime crash |
| Draft model loader | Handle sharded safetensors for draft models | DFlash broken |
| DeepseekV3.swift | Add `DFlashTargetModel` conformance | DFlash no-op |
| `pread_into` C++ | 3rd-request segfault (use-after-free?) | Runtime crash |
Happy to open separate issues for any of these or submit patches. The branch is very close to working — most of these are edge cases around MoE models with mixed precision and tiktoken tokenizers rather than fundamental architecture issues.
@hankbobtheresearchoor good bot 🤖
… limit

DeepseekV3DFlash.sanitize():
- Strip the 'language_model.' wrapper prefix present in kimi_k25 and some other HuggingFace exports so weight keys resolve to model.* paths
- After stacking per-expert weights into switch_mlp, remove the original experts.N.* keys to prevent a verify: .noUnusedKeys crash
- Generalize the layer filter to use numHiddenLayers instead of a hardcoded 61

Server.run():
- Raise RLIMIT_NOFILE to 4096 at startup; large sharded models (kimi_k25 has 182 safetensor shards) exhaust the default macOS limit of 256
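The RLIMIT_NOFILE raise uses the standard Darwin API; a sketch of how it might look in Server.run() (not the PR's exact code):

```swift
import Darwin

// 182 safetensor shards plus sockets easily exceed the default soft limit of 256.
var limit = rlimit()
if getrlimit(RLIMIT_NOFILE, &limit) == 0 {
    // Raise the soft limit, capped at the hard limit the kernel allows.
    limit.rlim_cur = min(rlim_t(4096), limit.rlim_max)
    _ = setrlimit(RLIMIT_NOFILE, &limit)
}
```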
Thanks for the detailed writeup bot, very useful. Here's where things stand after today's commits:

Fixed in this branch: …

Needs an upstream fix (mlx-swift): Issue 2, the sparse …

Still out of scope here: tiktoken (Issue 4), the 3rd-request SSD segfault (Issue 8, tracked separately), Metal toolchain preflight (Issue 9).
…prevent GPU timeouts
- Move MLX_MAX_OPS_PER_BUFFER=50 to the top of run() before Metal init
- Enable --stream-experts automatically on <12GB machines in test-dflash.sh so weights are paged via mmap/pread instead of macOS VM swap
- Auto-cap draft tokens to 1 under SSD streaming (minimal fan-out)
- Always compute draftFootprintBytes regardless of the --stream-experts flag
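A sketch of the ordering fix from the first item above: the environment variable must be set before the first MLX/Metal call, or the smaller command-buffer batching never takes effect (variable name and value come from the commit message):

```swift
import Darwin

// Must run before any Metal/MLX initialization.
_ = setenv("MLX_MAX_OPS_PER_BUFFER", "50", 1)
// ... only now proceed with model loading and Metal setup ...
```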
Merges ericjlake's prompt-cache fixes from PR SharpAI#85, resolving conflicts with the DFlash integration (PR SharpAI#78).

Changes from ericjlake:
- MambaCache safety gate + KVCacheSimple T-dim slice in save()
- ndim >= 3 guard in minCachedSeqLen scan
- Spec-decode short-circuit ordering (check before cache restore)
- README: Qwen3-A3B full-RAM perf table (M1 Ultra 64 GB)

Conflict resolution:
- README.md: kept both the Qwen3-A3B and DeepSeek-V4 perf tables
- Server.swift save(): kept the existing MambaCache early return + the new T-dim slice
- Server.swift decision branch: combined spec-decode-first + skipPromptCache (kvBits)

Closes SharpAI#84.

Co-authored-by: Eric Lake <ericjlake@users.noreply.github.com>
Summary
- New DFlash speculative decoding runtime (`Sources/DFlash/`) with `FullAttentionEngine` and `HybridGDNEngine`
- `DFlashTargetModel` conformance for `Qwen3NextModel` (27B) and `Qwen35Model` (35B)
- `MambaSnapshotCache` as a lightweight O(1) rollback alternative to `RecurrentRollbackCache`, selectable via `dflashUseTapeRollback`
- Branchless kernel state updates via `metal::select`, 2D kernel cache arrays
- `DFlashKernelBench` micro-benchmark target
- Reorganized test/tooling suite under `tests/DFlash/`
- Structured JSON export in `bench_35b.sh`; adds `bench_coder_next.sh`

Submodule dependency

Requires a 3-line change in SharpAI/mlx-swift-lm (commit `a707519`): adds `public` to `Qwen3NextModelInner.embedTokens`, `Qwen3NextModel.lmHead`, and `Qwen3NextModelInner.callCapturing`. This allows `DFlashTargetModel` conformance to live in `Sources/SwiftLM/Qwen3Next+DFlash.swift` (same pattern as `Qwen35+DFlash.swift`) rather than inside the submodule. The submodule pointer in this PR references that commit on SharpAI/mlx-swift-lm.

Test plan

- `swift build -c release` succeeds
- Baseline serving: `./SwiftLM --model <path> --port 5414`
- DFlash on the 27B model: `./SwiftLM --model <27B> --draft-model <draft> --dflash --port 5414`
- SSD streaming + DFlash on the 35B model: `./SwiftLM --model <35B> --stream-experts --dflash --draft-model <draft> --port 5414`
- Run `bench_35b.sh` and verify tok/s matches the benchmarks above