
Commit 7df2170

fix(server): prompt-cache bleed fixes + Qwen3-A3B perf table (#85)
fix(server): prompt-cache bleed fixes — MambaCache gate + ndim guard + spec-decode ordering
2 parents 29f3816 + 53b040d commit 7df2170

2 files changed: 64 additions & 5 deletions


README.md (15 additions, 0 deletions)
```diff
@@ -73,6 +73,21 @@ Benchmark results for `gemma-4-26b-a4b-it-4bit` (26B MoE, 4-bit) on M5 Pro 64 GB
 
 > Run `./run_benchmark.sh` to generate these metrics on your own device. (See **Benchmarks & Testing** below).
 
+### Qwen3.6-35B-A3B-UD-MLX-4bit (Full-RAM) — M1 Ultra 64 GB
+
+Benchmark results for full-RAM (no SSD streaming) MoE inference on M1 Ultra. The 3.4× improvement over earlier vanilla builds comes from the `needsMoeFlush` gate in `mlx-swift-lm` (see [SwiftLM #84](https://github.com/SharpAI/SwiftLM/issues/84)): the per-layer GPU sync barrier required for SSD streaming was firing unconditionally on the full-RAM path, flushing MLX's kernel-batching pipeline.
+
+| Configuration | Short (~126 tok) | Medium (~400 tok) | Long (~800 tok) |
+|---|---|---|---|
+| **Vanilla full-GPU** | **61.7 tok/s** | **62.3 tok/s** | **62.1 tok/s** |
+| `--dflash` (block_size=16) † | 52.3 tok/s | **70.3 tok/s** (+13%) | **69.9 tok/s** (+13%) |
+
+> *Hardware:* Apple M1 Ultra, 64 GB unified memory, macOS 26.x. Model ~20 GB on disk; ~21.6 GB resident weights + ~2.1 GB KV cache at runtime.
+> *Flags:* `--repeat-penalty 1.1 --max-tokens 2000`, `temperature: 0.6`, single-stream `/v1/chat/completions`.
+> *Vanilla baseline before the* `needsMoeFlush` *gate (for reference):* 19.2 / 18.1 / 18.3 tok/s (see #84).
+
+† DFlash uses [`z-lab/Qwen3.6-35B-A3B-DFlash`](https://huggingface.co/z-lab/Qwen3.6-35B-A3B-DFlash) (~948 MB) as the block-diffusion draft model. DFlash gives a clean +13% on medium and long generations but regresses short prompts (the block overhead doesn't amortize at low token counts) and changes stop-condition behavior (`finish_reason=null` instead of `stop`/`length`). We recommend a quality eval before making it the default.
+
 ### DeepSeek-V4-Flash (126 GB, Q3-mixed-gs128-affine) — M5 Pro 64 GB
 
 Model: [`Thump604/DeepSeek-V4-Flash-MLX-Q3-mixed-gs128-affine`](https://huggingface.co/Thump604/DeepSeek-V4-Flash-MLX-Q3-mixed-gs128-affine)
```
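For readers who want the shape of the `needsMoeFlush` fix in code: below is a minimal sketch of the gating idea, not the actual `mlx-swift-lm` implementation. The `MoELayerRunner` type and its `experts` closure are hypothetical; only the `needsMoeFlush` name comes from the commit, and `eval()` is MLX's real synchronization call.

```swift
import MLX

// Hypothetical wrapper illustrating the gate: flush only when SSD
// streaming actually requires a per-layer synchronization barrier.
struct MoELayerRunner {
    /// True only when expert weights stream from SSD and must be
    /// materialized before the next layer reuses their buffers.
    let needsMoeFlush: Bool

    func run(_ x: MLXArray, experts: (MLXArray) -> MLXArray) -> MLXArray {
        let y = experts(x)
        if needsMoeFlush {
            // SSD-streaming path: drain the kernel queue so expert
            // buffers can be released before the next layer loads.
            eval(y)
        }
        // Full-RAM path: keep the graph lazy so MLX can batch kernels
        // across layers, which is where the ~3.4x throughput returns.
        return y
    }
}
```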

Sources/SwiftLM/Server.swift (49 additions, 5 deletions)
```diff
@@ -1180,7 +1180,24 @@ actor PromptCache
         if cache.contains(where: { $0 is MambaCache }) {
             return
         }
-        let states = cache.map { $0.state }
+        let P = tokens.count
+        // For attention KVCacheSimple layers, the state tensor is [B, H, T, D] with a
+        // pre-allocated T that can exceed the actual prompt length P. If we store the
+        // full over-sized buffer, restore()'s trim() by (cached.tokens.count - matchLen)
+        // still leaves T - P slots of garbage beyond the valid prefix. Slice T to P at
+        // save time so cached.tokens.count === cached state's T.
+        let states: [[MLXArray]] = cache.map { layer -> [MLXArray] in
+            let s = layer.state
+            if layer is KVCacheSimple {
+                return s.map { arr -> MLXArray in
+                    guard arr.ndim >= 3 else { return arr }
+                    let T = arr.dim(2)
+                    if T > P { return arr[.ellipsis, ..<P, 0...] }
+                    return arr
+                }
+            }
+            return s
+        }
         let metaStates = cache.map { $0.metaState }
         // Materialize all lazy MLX arrays so they survive cache mutations
         let allArrays = states.flatMap { $0 }
```
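To see the save-time slice in isolation, here is a standalone toy with assumed shapes (it is not the Server.swift code): an over-allocated KV buffer with T = 32 slots holding a P = 5 token prefix is cut to exactly P along dim 2, so the stored state length matches `tokens.count`.

```swift
import MLX

let (B, H, T, D) = (1, 2, 32, 4)
let kvBuffer = MLXArray.zeros([B, H, T, D])  // pre-allocated: T > actual prompt
let P = 5                                    // actual prompt length

// Same guard + slice as the hunk above: leave low-dimensional states
// alone, trim the T axis of [B, H, T, D] attention states down to P.
let trimmed = (kvBuffer.ndim >= 3 && kvBuffer.dim(2) > P)
    ? kvBuffer[.ellipsis, ..<P, 0...]
    : kvBuffer
print(trimmed.shape)  // [1, 2, 5, 4]
```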
```diff
@@ -1206,6 +1223,20 @@ actor PromptCache
             misses += 1
             return nil
         }
+        // ── Recurrent-layer safety gate ──
+        // MambaCache (and other recurrent caches) store a 2-D hidden state with no
+        // T dimension, so the dim(2) read below would crash. Hybrid Mamba/attention
+        // models (Qwen-Next, Mamba-2, etc.) can't be safely prefix-restored because
+        // the recurrent hidden state was computed over the WHOLE previous sequence
+        // and there is no trim(excess) operator for it. Treat any cache containing
+        // a recurrent layer as a miss before we touch anything.
+        let hasRecurrentLayer = cache.contains { layer in
+            !(layer is KVCacheSimple) && !(String(describing: type(of: layer)).contains("Rotating"))
+        }
+        if hasRecurrentLayer {
+            misses += 1
+            return nil
+        }
         // Token-by-token longest common prefix scan
         var matchLen = 0
         for (a, b) in zip(cached.tokens, newTokens) {
```
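As a toy illustration of why the gate is purely type-based, the snippet below mirrors the heuristic with one-line stand-ins for the real `mlx-swift-lm` cache types (the stubs are only for demonstration): anything that is neither a plain `KVCacheSimple` nor a `Rotating*` cache is treated as recurrent and forces a miss.

```swift
// Stand-in cache types; the real ones live in mlx-swift-lm.
protocol KVCache {}
final class KVCacheSimple: KVCache {}
final class RotatingKVCache: KVCache {}
final class MambaCache: KVCache {}  // 2-D hidden state, no T axis

// Same shape as the gate in the hunk above.
func containsRecurrentLayer(_ cache: [KVCache]) -> Bool {
    cache.contains { layer in
        !(layer is KVCacheSimple)
            && !String(describing: type(of: layer)).contains("Rotating")
    }
}

print(containsRecurrentLayer([KVCacheSimple(), RotatingKVCache()]))  // false: safe to restore
print(containsRecurrentLayer([KVCacheSimple(), MambaCache()]))       // true: forced miss
```

The string match on "Rotating" is brittle against renames, but it fails in the safe direction: anything unrecognized counts as recurrent and is refused rather than restored into a corrupt state.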
```diff
@@ -1226,6 +1257,7 @@ actor PromptCache
         // dim(2) = T = the number of cached tokens for that layer.
         let minCachedSeqLen = cached.states.map { arrays -> Int in
             guard let firstArray = arrays.first else { return 0 }
+            guard firstArray.ndim >= 3 else { return 0 }
             return firstArray.dim(2) // T dimension
         }.min() ?? 0
         if excess >= minCachedSeqLen {
```
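The effect of the new `ndim` guard is easiest to see standalone. In this assumed-shape sketch, a 2-D recurrent-style state contributes 0 to the minimum, so the `excess >= minCachedSeqLen` check fires for any excess instead of crashing on a `dim(2)` read:

```swift
import MLX

// One attention layer with T = 7, one 2-D state with no T axis.
let states: [[MLXArray]] = [
    [MLXArray.zeros([1, 2, 7, 4])],
    [MLXArray.zeros([1, 16])],
]
let minCachedSeqLen = states.map { arrays -> Int in
    guard let first = arrays.first, first.ndim >= 3 else { return 0 }
    return first.dim(2)
}.min() ?? 0
print(minCachedSeqLen)  // 0, so the defensive branch always takes over
```

With the recurrent-layer gate above, such 2-D states should never reach this point; the guard is belt-and-braces.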
```diff
@@ -1520,13 +1552,25 @@ func handleChatCompletion(
     // raw <|image|>/<|audio|> token embeddings instead of the projected features.
     let isMultimodalRequest = lmInput.image != nil || lmInput.audio != nil
 
-    // Try to restore via token-by-token prefix match (llama-server style).
-    // Skip for quantized-KV requests: the prompt cache stores KV state produced
-    // with KVCacheSimple; restoring it into a QuantizedKVCache (or vice-versa)
+    // ── Decision branch ──
+    // Speculative decoding is CHECKED FIRST because a cache-hit rollback
+    // corrupts the draft model's KV state (draft and main model cycle tokens
+    // in lock-step). We'd rather pay the prefill than emit garbage.
+    //
+    // Skip prompt cache for quantized-KV requests: the prompt cache stores KV state
+    // produced with KVCacheSimple; restoring it into a QuantizedKVCache (or vice-versa)
     // is unsafe and produces incorrect results or runtime failures.
     let skipPromptCache = isMultimodalRequest || params.kvBits != nil
     var stream: AsyncStream<Generation>
-    if !skipPromptCache, let cachedCount = await promptCache.restore(newTokens: promptTokens, into: cache) {
+    if let draftRef = draftModelRef {
+        // Speculative decoding path: draft model generates candidates, main model verifies.
+        // Bypass prompt cache to avoid draft/main KV drift on partial-match restores.
+        print("[SwiftLM] Using speculative decoding (\(numDraftTokens) draft tokens/round)")
+        stream = try MLXLMCommon.generate(
+            input: lmInput, cache: cache, parameters: params, context: context,
+            draftModel: draftRef.model, numDraftTokens: numDraftTokens
+        )
+    } else if !skipPromptCache, let cachedCount = await promptCache.restore(newTokens: promptTokens, into: cache) {
         // Cache hit: KV state is pre-populated up to cachedCount tokens.
         // Only compute the remaining (new) tokens.
         var startIndex = cachedCount
```
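Boiled down, the new routing logic is a three-way branch. The condensed sketch below uses hypothetical types (the real handler streams `Generation` events) and exists only to show the ordering argument: a configured draft model always wins over a prompt-cache restore, since a partial-prefix restore would desynchronize the draft and main KV caches:

```swift
enum DecodePath { case speculative, cachedPrefix(Int), fullPrefill }

func choosePath(hasDraftModel: Bool,
                skipPromptCache: Bool,
                cachedPrefixLen: Int?) -> DecodePath {
    if hasDraftModel {
        // Checked first: never mix a prefix restore with lock-step
        // draft/main decoding.
        return .speculative
    }
    if !skipPromptCache, let n = cachedPrefixLen {
        return .cachedPrefix(n)  // reuse n prefilled tokens
    }
    return .fullPrefill
}

print(choosePath(hasDraftModel: true,  skipPromptCache: false, cachedPrefixLen: 900))
// speculative: the cached 900-token prefix is deliberately ignored
print(choosePath(hasDraftModel: false, skipPromptCache: false, cachedPrefixLen: 900))
// cachedPrefix(900)
```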
