From d38fe8ed543d29d34e704905b4d595503ef382de Mon Sep 17 00:00:00 2001
From: Eric
Date: Wed, 22 Apr 2026 04:05:04 -0700
Subject: [PATCH 1/2] Re-apply prompt-cache bleed fixes to synced main
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Three fixes, now riding on upstream 116ee91:

1. save(): slice the KVCacheSimple state T-dim down to P = tokens.count so
   the cached state's T matches cached.tokens.count. Prevents the
   over-allocated prefill buffer from carrying uninitialized tokens past the
   valid prefix.

2. restore(): gate out recurrent-state layers (MambaCache and friends) up
   front. Their state is 2-D with no T dimension, so the dim(2) read in the
   pre-flight check would crash; there is also no trim(excess) operator for a
   recurrent hidden state, so we can't partial-restore one safely. Guard with
   ndim >= 3 inside the min-length scan too, for belt and suspenders.

3. handleChatCompletion(): reorder the decision branch so speculative
   decoding is checked BEFORE the prompt-cache restore. A cache-hit rollback
   corrupts the draft model's KV state (draft and main cycle tokens in
   lock-step), so when draftModelRef is set we bypass the cache entirely and
   pay the full prefill. Partial-match restores stay available on the
   non-spec path, where they still pay off.
---
 Sources/SwiftLM/Server.swift | 56 +++++++++++++++++++++++++++++-------
 1 file changed, 46 insertions(+), 10 deletions(-)

diff --git a/Sources/SwiftLM/Server.swift b/Sources/SwiftLM/Server.swift
index 17d68d37..f06ca258 100644
--- a/Sources/SwiftLM/Server.swift
+++ b/Sources/SwiftLM/Server.swift
@@ -946,7 +946,24 @@ actor PromptCache {
     /// If not materialized now, those lazy references point to the live cache tensors
     /// which get overwritten by subsequent requests, causing stale data / SIGTRAP on restore.
     func save(tokens: [Int], cache: [KVCache]) {
-        let states = cache.map { $0.state }
+        let P = tokens.count
+        // For attention KVCacheSimple layers, the state tensor is [B, H, T, D] with a
+        // pre-allocated T that can exceed the actual prompt length P. If we store the
+        // full over-sized buffer, restore()'s trim() by (cached.tokens.count - matchLen)
+        // still leaves T - P slots of garbage beyond the valid prefix. Slice T to P at
+        // save time so cached.tokens.count === cached state's T.
+        let states: [[MLXArray]] = cache.map { layer -> [MLXArray] in
+            let s = layer.state
+            if layer is KVCacheSimple {
+                return s.map { arr -> MLXArray in
+                    guard arr.ndim >= 3 else { return arr }
+                    let T = arr.dim(2)
+                    if T > P { return arr[.ellipsis, ..<P, 0...] }
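
[Note, not part of the patch: the restore()-side gating described in fix 2 is not visible in the hunk above. A minimal sketch of that pre-flight check, reusing the patch's own KVCache / MambaCache / MLXArray vocabulary — the helper name `canPartialRestore` and the exact control flow are assumptions, not the actual diff:

```swift
// Sketch only. Fix 2: refuse a partial restore when any layer carries
// recurrent state, before any dim(2) read can happen.
func canPartialRestore(_ cache: [KVCache]) -> Bool {
    for layer in cache {
        // Recurrent hidden states (MambaCache and friends) are 2-D with no T
        // dimension: dim(2) would crash, and there is no trim(excess)
        // operator to roll one back to a shorter prefix.
        if layer is MambaCache { return false }
        // Belt-and-suspenders inside the min-length scan: only states with
        // ndim >= 3 have a T axis we can measure and trim.
        if layer.state.contains(where: { $0.ndim < 3 }) { return false }
    }
    return true
}
```

When this returns false, restore() falls through to the full-prefill path, matching the "bypass the cache entirely" behavior the commit message describes for the speculative-decoding case.]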