@@ -772,35 +772,54 @@ actor ServerStats {
772772
773773actor PromptCache {
774774 struct CachedState {
775+ let tokens : [ Int ] // Full token sequence that generated this KV state
775776 let states : [ [ MLXArray ] ] // Per-layer KV state arrays
776777 let metaStates : [ [ String ] ] // Per-layer metadata
777- let tokenCount : Int // Number of cached tokens
778778 }
779779
780780 private var cached : CachedState ?
781- private var cachedTokenHash : Int ?
782781 private var hits : Int = 0
783782 private var misses : Int = 0
784783
785- func save( tokenHash: Int , cache: [ KVCache ] , tokenCount: Int ) {
784+ /// Save the full prompt token sequence and its KV state.
785+ func save( tokens: [ Int ] , cache: [ KVCache ] ) {
786786 let states = cache. map { $0. state }
787787 let metaStates = cache. map { $0. metaState }
788- cached = CachedState ( states: states, metaStates: metaStates, tokenCount: tokenCount)
789- cachedTokenHash = tokenHash
788+ cached = CachedState ( tokens: tokens, states: states, metaStates: metaStates)
790789 }
791790
792- func restore( tokenHash: Int , into cache: [ KVCache ] ) -> Int ? {
793- guard let cached, cachedTokenHash == tokenHash else {
791+ /// Find the longest common prefix between `newTokens` and the cached sequence.
792+ /// Restores matched KV state, trims any excess — mirrors llama-server behaviour.
793+ /// Returns the number of matched tokens, or nil on a complete miss.
794+ func restore( newTokens: [ Int ] , into cache: [ KVCache ] ) -> Int ? {
795+ guard let cached, !cached. tokens. isEmpty else {
794796 misses += 1
795797 return nil
796798 }
799+ // Token-by-token longest common prefix scan
800+ var matchLen = 0
801+ for (a, b) in zip ( cached. tokens, newTokens) {
802+ guard a == b else { break }
803+ matchLen += 1
804+ }
805+ guard matchLen > 0 else {
806+ misses += 1
807+ return nil
808+ }
809+ // Restore full cached KV state into each layer
797810 for i in 0 ..< min ( cache. count, cached. states. count) {
798811 var layer = cache [ i]
799812 layer. state = cached. states [ i]
800813 layer. metaState = cached. metaStates [ i]
801814 }
815+ // Trim excess if we only matched a partial prefix
816+ let excess = cached. tokens. count - matchLen
817+ if excess > 0 {
818+ for layer in cache { layer. trim ( excess) }
819+ }
802820 hits += 1
803- return cached. tokenCount
821+ print ( " [SwiftLM] \u{1F5C2} Prompt cache HIT: \( matchLen) / \( newTokens. count) tokens reused ( \( excess > 0 ? " partial " : " full " ) match) " )
822+ return matchLen
804823 }
805824
806825 func stats( ) -> ( hits: Int , misses: Int ) { ( hits, misses) }
@@ -910,9 +929,9 @@ func handleChatCompletion(
910929 let userInput = UserInput ( chat: chatMessages, tools: toolSpecs, additionalContext: templateContext)
911930 let lmInput = try await container. prepare ( input: userInput)
912931
913- // ── Prompt caching: compute hash for system prompt ──
932+ // ── Prompt caching: full token sequence for prefix matching ──
914933 let promptTokenCount = lmInput. text. tokens. size
915- let systemHash = systemPromptText . hashValue
934+ let promptTokens = lmInput . text . tokens . asArray ( Int . self )
916935
917936 // llama-server style: announce prefill start
918937 print ( " srv slot_launch: id 0 | prompt= \( promptTokenCount) t | thinking= \( enableThinking) | prefilling... " )
@@ -934,45 +953,26 @@ func handleChatCompletion(
934953 }
935954 }
936955
937- // Try to restore cached system prompt KV state
938- if let cachedCount = await promptCache. restore ( tokenHash: systemHash, into: cache) {
939- // Cache hit: skip the cached prefix tokens, process only the rest
956+ // Try to restore via token-by-token prefix match (llama-server style)
957+ if let cachedCount = await promptCache. restore ( newTokens: promptTokens, into: cache) {
958+ // Cache hit: KV state is pre-populated up to cachedCount tokens.
959+ // Only compute the remaining (new) tokens.
940960 let remainingTokens = lmInput. text. tokens [ cachedCount... ]
941961 let trimmedInput = LMInput ( tokens: remainingTokens)
942962 return try MLXLMCommon . generate (
943963 input: trimmedInput, cache: cache, parameters: params, context: context
944964 )
945965 } else {
946- // Cache miss: process everything, then save system prompt state
947- // Count system prompt tokens using the tokenizer
948- var systemTokenCount = 0
949- if !systemPromptText. isEmpty {
950- // Approximate system token count from the tokenizer
951- let sysTokens = context. tokenizer. encode ( text: systemPromptText)
952- // Add overhead for chat template tokens (BOS, role markers, etc.)
953- systemTokenCount = sysTokens. count + 4
954- }
955-
966+ // Cache miss: process the full prompt.
956967 let stream = try MLXLMCommon . generate (
957968 input: lmInput, cache: cache, parameters: params, context: context
958969 )
959-
960- // Save cache state after prefill (cache now contains all prompt tokens)
961- if systemTokenCount > 0 {
962- // Save the full prompt cache, but record the system token count
963- // so future requests with different user messages can still benefit
964- // (they'll have the system prefix cached)
965- //
966- // Note: We save after generate() starts, which means the cache
967- // has been populated by the TokenIterator's prepare() call.
968- // We use Task to save asynchronously after the first token.
969- Task {
970- // Small delay to let prefill complete and populate the cache
971- try ? await Task . sleep ( for: . milliseconds( 100 ) )
972- await promptCache. save ( tokenHash: systemHash, cache: cache, tokenCount: systemTokenCount)
973- }
970+ // Save full prompt tokens + KV state so the next request can prefix-match
971+ // any shared prefix (system prompt, conversation history, long documents, etc.)
972+ Task {
973+ try ? await Task . sleep ( for: . milliseconds( 100 ) )
974+ await promptCache. save ( tokens: promptTokens, cache: cache)
974975 }
975-
976976 return stream
977977 }
978978 }
0 commit comments