15 changes: 15 additions & 0 deletions README.md
@@ -73,6 +73,21 @@ Benchmark results for `gemma-4-26b-a4b-it-4bit` (26B MoE, 4-bit) on M5 Pro 64 GB

> Run `./run_benchmark.sh` to generate these metrics on your own device. (See **Benchmarks & Testing** below).

### Qwen3.6-35B-A3B-UD-MLX-4bit (Full-RAM) — M1 Ultra 64 GB

Benchmark results for full-RAM (no SSD streaming) MoE inference on M1 Ultra. The 3.4× improvement in the vanilla baseline over earlier builds comes from the `needsMoeFlush` gate in `mlx-swift-lm` (see [SwiftLM #84](https://github.com/SharpAI/SwiftLM/issues/84)): the per-layer GPU sync barrier required for SSD streaming was firing unconditionally on the full-RAM path, flushing MLX's kernel-batching pipeline on every layer.

| Configuration | Short (~126 tok) | Medium (~400 tok) | Long (~800 tok) |
|---|---|---|---|
| **Vanilla full-GPU** | **61.7 tok/s** | **62.3 tok/s** | **62.1 tok/s** |
| `--dflash` (block_size=16) † | 52.3 tok/s | **70.3 tok/s** (+13%) | **69.9 tok/s** (+13%) |

> *Hardware:* Apple M1 Ultra, 64 GB unified memory, macOS 26.x. Model is ~20 GB on disk; ~21.6 GB resident weights + ~2.1 GB KV cache at runtime.
> *Flags:* `--repeat-penalty 1.1 --max-tokens 2000`, `temperature: 0.6`, single-stream `/v1/chat/completions`.
> *Vanilla baseline before* `needsMoeFlush` *gate (for reference):* 19.2 / 18.1 / 18.3 tok/s — see #84.

† DFlash uses [`z-lab/Qwen3.6-35B-A3B-DFlash`](https://huggingface.co/z-lab/Qwen3.6-35B-A3B-DFlash) (~948 MB) as the block-diffusion draft model. DFlash gives a clean +13% on medium/long generations but regresses short prompts (the block overhead doesn't amortize at low token counts) and changes stop-condition behavior (`finish_reason=null` vs `stop`/`length`). We recommend a quality eval before using it as the default.
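For readers unfamiliar with the `needsMoeFlush` change referenced above, the sketch below illustrates the general shape of such a per-layer flush gate: the GPU sync barrier fires only when experts are streamed from SSD, so the full-RAM path keeps MLX's lazy kernel batching intact. This is an illustrative sketch, not the actual `mlx-swift-lm` code; apart from the `needsMoeFlush` name, the type and member names (`MoELayer`, `ssdStreaming`, `computeExperts`) are hypothetical.

```swift
import MLX

// Illustrative sketch only; the real gate lives in mlx-swift-lm (SwiftLM #84).
// `MoELayer`, `ssdStreaming`, and `computeExperts` are hypothetical names.
struct MoELayer {
    /// True only when expert weights are streamed from SSD and their buffers
    /// must be released before the next layer maps its own experts.
    let ssdStreaming: Bool

    /// The gate: a flush is only needed on the SSD-streaming path.
    var needsMoeFlush: Bool { ssdStreaming }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        let y = computeExperts(x)
        if needsMoeFlush {
            // Per-layer sync barrier: materialize pending kernels so streamed
            // expert buffers can be evicted. Skipping this on the full-RAM
            // path lets MLX keep batching kernels across layers.
            eval(y)
        }
        return y
    }

    private func computeExperts(_ x: MLXArray) -> MLXArray {
        // Placeholder for the routed expert computation.
        x
    }
}
```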

### DeepSeek-V4-Flash (126 GB, Q3-mixed-gs128-affine) — M5 Pro 64 GB

Model: [`Thump604/DeepSeek-V4-Flash-MLX-Q3-mixed-gs128-affine`](https://huggingface.co/Thump604/DeepSeek-V4-Flash-MLX-Q3-mixed-gs128-affine)
54 changes: 49 additions & 5 deletions Sources/SwiftLM/Server.swift
@@ -1180,7 +1180,24 @@ actor PromptCache {
if cache.contains(where: { $0 is MambaCache }) {
return
}
let states = cache.map { $0.state }
let P = tokens.count
// For attention KVCacheSimple layers, the state tensor is [B, H, T, D] with a
// pre-allocated T that can exceed the actual prompt length P. If we store the
// full over-sized buffer, restore()'s trim() by (cached.tokens.count - matchLen)
// still leaves T - P slots of garbage beyond the valid prefix. Slice T to P at
// save time so cached.tokens.count === cached state's T.
let states: [[MLXArray]] = cache.map { layer -> [MLXArray] in
let s = layer.state
if layer is KVCacheSimple {
return s.map { arr -> MLXArray in
guard arr.ndim >= 3 else { return arr }
let T = arr.dim(2)
if T > P { return arr[.ellipsis, ..<P, 0...] }
return arr
}
}
return s
}
let metaStates = cache.map { $0.metaState }
// Materialize all lazy MLX arrays so they survive cache mutations
let allArrays = states.flatMap { $0 }
@@ -1206,6 +1223,20 @@
misses += 1
return nil
}
// ── Recurrent-layer safety gate ──
// MambaCache (and other recurrent caches) store a 2-D hidden state with no
// T dimension, so the dim(2) read below would crash. Hybrid Mamba/attention
// models (Qwen-Next, Mamba-2, etc.) can't be safely prefix-restored because
// the recurrent hidden state was computed over the WHOLE previous sequence
// and there is no trim(excess) operator for it. Treat any cache containing
// a recurrent layer as a miss before we touch anything.
let hasRecurrentLayer = cache.contains { layer in
!(layer is KVCacheSimple) && !(String(describing: type(of: layer)).contains("Rotating"))
}
if hasRecurrentLayer {
misses += 1
return nil
}
// Token-by-token longest common prefix scan
var matchLen = 0
for (a, b) in zip(cached.tokens, newTokens) {
@@ -1226,6 +1257,7 @@
// dim(2) = T = the number of cached tokens for that layer.
let minCachedSeqLen = cached.states.map { arrays -> Int in
guard let firstArray = arrays.first else { return 0 }
guard firstArray.ndim >= 3 else { return 0 }
return firstArray.dim(2) // T dimension
}.min() ?? 0
if excess >= minCachedSeqLen {
@@ -1520,13 +1552,25 @@ func handleChatCompletion(
// raw <|image|>/<|audio|> token embeddings instead of the projected features.
let isMultimodalRequest = lmInput.image != nil || lmInput.audio != nil

// Try to restore via token-by-token prefix match (llama-server style).
// Skip for quantized-KV requests: the prompt cache stores KV state produced
// with KVCacheSimple; restoring it into a QuantizedKVCache (or vice-versa)
// ── Decision branch ──
// Speculative decoding is CHECKED FIRST because a cache-hit rollback
// corrupts the draft model's KV state (draft and main model cycle tokens
// in lock-step). We'd rather pay the prefill than emit garbage.
//
// Skip prompt cache for quantized-KV requests: the prompt cache stores KV state
// produced with KVCacheSimple; restoring it into a QuantizedKVCache (or vice-versa)
// is unsafe and produces incorrect results or runtime failures.
let skipPromptCache = isMultimodalRequest || params.kvBits != nil
var stream: AsyncStream<Generation>
if !skipPromptCache, let cachedCount = await promptCache.restore(newTokens: promptTokens, into: cache) {
if let draftRef = draftModelRef {
// Speculative decoding path: draft model generates candidates, main model verifies.
// Bypass prompt cache to avoid draft/main KV drift on partial-match restores.
print("[SwiftLM] Using speculative decoding (\(numDraftTokens) draft tokens/round)")
stream = try MLXLMCommon.generate(
input: lmInput, cache: cache, parameters: params, context: context,
draftModel: draftRef.model, numDraftTokens: numDraftTokens
)
} else if !skipPromptCache, let cachedCount = await promptCache.restore(newTokens: promptTokens, into: cache) {
// Cache hit: KV state is pre-populated up to cachedCount tokens.
// Only compute the remaining (new) tokens.
var startIndex = cachedCount