Skip to content

Commit 01df003

Browse files
committed
feat(prompt-cache): token-by-token prefix match (llama-server style)
Replace the hash-based system-prompt cache with a longest-common-prefix scan:
- Store the full token sequence alongside the KV state
- On each request: scan token-by-token to find the longest shared prefix
- Restore KV state, trimming excess via layer.trim() for partial matches
- Save the full prompt after every request (not just the system prompt)

Benefits:
- System prompt is matched exactly — no token-count approximation bug
- Conversation-history reuse (any shared prefix, not just the system prompt)
- Partial prefix matches (e.g. same system prompt + first N turns) also benefit
- Works correctly with TurboKV (the state getter now returns the full fp16 context)
1 parent 6ede853 commit 01df003

File tree

1 file changed

+39
-39
lines changed

1 file changed

+39
-39
lines changed

Sources/SwiftLM/Server.swift

Lines changed: 39 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -772,35 +772,54 @@ actor ServerStats {
772772

773773
actor PromptCache {
774774
struct CachedState {
775+
let tokens: [Int] // Full token sequence that generated this KV state
775776
let states: [[MLXArray]] // Per-layer KV state arrays
776777
let metaStates: [[String]] // Per-layer metadata
777-
let tokenCount: Int // Number of cached tokens
778778
}
779779

780780
private var cached: CachedState?
781-
private var cachedTokenHash: Int?
782781
private var hits: Int = 0
783782
private var misses: Int = 0
784783

785-
func save(tokenHash: Int, cache: [KVCache], tokenCount: Int) {
784+
/// Save the full prompt token sequence and its KV state.
785+
func save(tokens: [Int], cache: [KVCache]) {
786786
let states = cache.map { $0.state }
787787
let metaStates = cache.map { $0.metaState }
788-
cached = CachedState(states: states, metaStates: metaStates, tokenCount: tokenCount)
789-
cachedTokenHash = tokenHash
788+
cached = CachedState(tokens: tokens, states: states, metaStates: metaStates)
790789
}
791790

792-
func restore(tokenHash: Int, into cache: [KVCache]) -> Int? {
793-
guard let cached, cachedTokenHash == tokenHash else {
791+
/// Find the longest common prefix between `newTokens` and the cached sequence.
792+
/// Restores matched KV state, trims any excess — mirrors llama-server behaviour.
793+
/// Returns the number of matched tokens, or nil on a complete miss.
794+
func restore(newTokens: [Int], into cache: [KVCache]) -> Int? {
795+
guard let cached, !cached.tokens.isEmpty else {
794796
misses += 1
795797
return nil
796798
}
799+
// Token-by-token longest common prefix scan
800+
var matchLen = 0
801+
for (a, b) in zip(cached.tokens, newTokens) {
802+
guard a == b else { break }
803+
matchLen += 1
804+
}
805+
guard matchLen > 0 else {
806+
misses += 1
807+
return nil
808+
}
809+
// Restore full cached KV state into each layer
797810
for i in 0..<min(cache.count, cached.states.count) {
798811
var layer = cache[i]
799812
layer.state = cached.states[i]
800813
layer.metaState = cached.metaStates[i]
801814
}
815+
// Trim excess if we only matched a partial prefix
816+
let excess = cached.tokens.count - matchLen
817+
if excess > 0 {
818+
for layer in cache { layer.trim(excess) }
819+
}
802820
hits += 1
803-
return cached.tokenCount
821+
print("[SwiftLM] \u{1F5C2} Prompt cache HIT: \(matchLen)/\(newTokens.count) tokens reused (\(excess > 0 ? "partial" : "full") match)")
822+
return matchLen
804823
}
805824

806825
func stats() -> (hits: Int, misses: Int) { (hits, misses) }
@@ -910,9 +929,9 @@ func handleChatCompletion(
910929
let userInput = UserInput(chat: chatMessages, tools: toolSpecs, additionalContext: templateContext)
911930
let lmInput = try await container.prepare(input: userInput)
912931

913-
// ── Prompt caching: compute hash for system prompt ──
932+
// ── Prompt caching: full token sequence for prefix matching ──
914933
let promptTokenCount = lmInput.text.tokens.size
915-
let systemHash = systemPromptText.hashValue
934+
let promptTokens = lmInput.text.tokens.asArray(Int.self)
916935

917936
// llama-server style: announce prefill start
918937
print("srv slot_launch: id 0 | prompt=\(promptTokenCount)t | thinking=\(enableThinking) | prefilling...")
@@ -934,45 +953,26 @@ func handleChatCompletion(
934953
}
935954
}
936955

937-
// Try to restore cached system prompt KV state
938-
if let cachedCount = await promptCache.restore(tokenHash: systemHash, into: cache) {
939-
// Cache hit: skip the cached prefix tokens, process only the rest
956+
// Try to restore via token-by-token prefix match (llama-server style)
957+
if let cachedCount = await promptCache.restore(newTokens: promptTokens, into: cache) {
958+
// Cache hit: KV state is pre-populated up to cachedCount tokens.
959+
// Only compute the remaining (new) tokens.
940960
let remainingTokens = lmInput.text.tokens[cachedCount...]
941961
let trimmedInput = LMInput(tokens: remainingTokens)
942962
return try MLXLMCommon.generate(
943963
input: trimmedInput, cache: cache, parameters: params, context: context
944964
)
945965
} else {
946-
// Cache miss: process everything, then save system prompt state
947-
// Count system prompt tokens using the tokenizer
948-
var systemTokenCount = 0
949-
if !systemPromptText.isEmpty {
950-
// Approximate system token count from the tokenizer
951-
let sysTokens = context.tokenizer.encode(text: systemPromptText)
952-
// Add overhead for chat template tokens (BOS, role markers, etc.)
953-
systemTokenCount = sysTokens.count + 4
954-
}
955-
966+
// Cache miss: process the full prompt.
956967
let stream = try MLXLMCommon.generate(
957968
input: lmInput, cache: cache, parameters: params, context: context
958969
)
959-
960-
// Save cache state after prefill (cache now contains all prompt tokens)
961-
if systemTokenCount > 0 {
962-
// Save the full prompt cache, but record the system token count
963-
// so future requests with different user messages can still benefit
964-
// (they'll have the system prefix cached)
965-
//
966-
// Note: We save after generate() starts, which means the cache
967-
// has been populated by the TokenIterator's prepare() call.
968-
// We use Task to save asynchronously after the first token.
969-
Task {
970-
// Small delay to let prefill complete and populate the cache
971-
try? await Task.sleep(for: .milliseconds(100))
972-
await promptCache.save(tokenHash: systemHash, cache: cache, tokenCount: systemTokenCount)
973-
}
970+
// Save full prompt tokens + KV state so the next request can prefix-match
971+
// any shared prefix (system prompt, conversation history, long documents, etc.)
972+
Task {
973+
try? await Task.sleep(for: .milliseconds(100))
974+
await promptCache.save(tokens: promptTokens, cache: cache)
974975
}
975-
976976
return stream
977977
}
978978
}

0 commit comments

Comments (0)