@@ -1655,9 +1655,16 @@ class LlamaCppService {
16551655 final model = _models[modelHandle]! ;
16561656 final modelParams = _contextParams[contextHandle]! ;
16571657 final vocab = llama_model_get_vocab (model.pointer);
1658+ final hasMediaParts =
1659+ parts? .any ((p) => p is LlamaImageContent || p is LlamaAudioContent ) ??
1660+ false ;
16581661
16591662 // 1. Reset Context
1660- ctx = _resetContext (contextHandle, ctx);
1663+ ctx = _resetContext (
1664+ contextHandle,
1665+ ctx,
1666+ clearMemory: hasMediaParts || ! params.reusePromptPrefix,
1667+ );
16611668
16621669 // 2. Prepare Resources
16631670 final nCtx = llama_n_ctx (ctx.pointer);
@@ -1689,6 +1696,7 @@ class LlamaCppService {
16891696 tokensPtr,
16901697 nCtx,
16911698 modelParams,
1699+ allowTextPromptReuse: ! hasMediaParts && params.reusePromptPrefix,
16921700 );
16931701
16941702 // 4. Initialize and Run Sampler Loop
@@ -1739,18 +1747,27 @@ class LlamaCppService {
/// Helper: Prepares a context for a new generation run.
///
/// Calls `llama_synchronize` on the native context first. When [clearMemory]
/// is true, the context memory is cleared and `ctx.cachedPromptTokens` is
/// dropped; when it is false, both are left as they are.
///
/// The wrapper is then stored under [contextHandle] in `_contexts` and
/// returned to the caller.
_LlamaContextWrapper _resetContext(
  int contextHandle,
  _LlamaContextWrapper ctx, {
  required bool clearMemory,
}) {
  final nativeContext = ctx.pointer;
  llama_synchronize(nativeContext);

  if (clearMemory) {
    // The token cache is only dropped once the memory clear has succeeded.
    _clearContextMemory(nativeContext);
    ctx.cachedPromptTokens = null;
  }

  return _contexts[contextHandle] = ctx;
}
1763+
/// Helper: Clears the memory attached to [contextPointer].
///
/// Looks the memory handle up with `llama_get_memory` and passes it to
/// `llama_memory_clear` with `true` as the second argument
/// (NOTE(review): presumably this also wipes the data buffers - confirm
/// against llama.h).
///
/// Throws an [Exception] when the memory handle is `nullptr`.
void _clearContextMemory(Pointer<llama_context> contextPointer) {
  final contextMemory = llama_get_memory(contextPointer);
  if (contextMemory != nullptr) {
    llama_memory_clear(contextMemory, true);
    return;
  }

  throw Exception("Failed to reset context memory");
}
17551772
17561773 /// Helper: Ingests the prompt (text or multimodal) and returns initial token count.
@@ -1764,8 +1781,9 @@ class LlamaCppService {
17641781 List <LlamaContentPart >? parts,
17651782 Pointer <Int32 > tokensPtr,
17661783 int nCtx,
1767- llama_context_params modelParams,
1768- ) {
1784+ llama_context_params modelParams, {
1785+ required bool allowTextPromptReuse,
1786+ }) {
17691787 final mediaParts =
17701788 parts
17711789 ? .where ((p) => p is LlamaImageContent || p is LlamaAudioContent )
@@ -1784,7 +1802,15 @@ class LlamaCppService {
17841802 modelParams,
17851803 );
17861804 } else {
1787- return _ingestTextPrompt (batch, vocab, prompt, tokensPtr, nCtx, ctx);
1805+ return _ingestTextPrompt (
1806+ batch,
1807+ vocab,
1808+ prompt,
1809+ tokensPtr,
1810+ nCtx,
1811+ ctx,
1812+ allowPromptReuse: allowTextPromptReuse,
1813+ );
17881814 }
17891815 }
17901816
@@ -1902,6 +1928,7 @@ class LlamaCppService {
19021928 malloc.free (bitmaps);
19031929 _mtmdInputChunksFree (chunks);
19041930 }
1931+ ctx.cachedPromptTokens = null ;
19051932 return initialTokens;
19061933 }
19071934
@@ -1991,8 +2018,9 @@ class LlamaCppService {
19912018 String prompt,
19922019 Pointer <Int32 > tokensPtr,
19932020 int nCtx,
1994- _LlamaContextWrapper ctx,
1995- ) {
2021+ _LlamaContextWrapper ctx, {
2022+ required bool allowPromptReuse,
2023+ }) {
19962024 final promptPtr = prompt.toNativeUtf8 ();
19972025 final shouldAddSpecial = ! _promptStartsWithBosToken (vocab, prompt);
19982026 final nTokens = llama_tokenize (
@@ -2010,20 +2038,123 @@ class LlamaCppService {
20102038 throw Exception ("Tokenization failed or prompt too long" );
20112039 }
20122040
2013- batch.n_tokens = nTokens;
2014- for (int i = 0 ; i < nTokens; i++ ) {
2015- batch.token[i] = tokensPtr[i];
2016- batch.pos[i] = i;
2041+ if (! allowPromptReuse || nTokens == 0 ) {
2042+ return _decodeAndCacheFullPrompt (batch, tokensPtr, ctx, nTokens);
2043+ }
2044+
2045+ final cachedTokens = ctx.cachedPromptTokens;
2046+ if (cachedTokens == null || cachedTokens.isEmpty) {
2047+ return _decodeAndCacheFullPrompt (batch, tokensPtr, ctx, nTokens);
2048+ }
2049+
2050+ final reusedPrefix = _sharedPrefixLength (cachedTokens, tokensPtr, nTokens);
2051+
2052+ if (reusedPrefix <= 0 || reusedPrefix >= nTokens) {
2053+ final canReuseCachedCopy =
2054+ reusedPrefix == nTokens && cachedTokens.length == nTokens;
2055+ return _decodeAndCacheFullPrompt (
2056+ batch,
2057+ tokensPtr,
2058+ ctx,
2059+ nTokens,
2060+ existingCachedTokens: canReuseCachedCopy ? cachedTokens : null ,
2061+ );
2062+ }
2063+
2064+ final memory = llama_get_memory (ctx.pointer);
2065+ if (memory == nullptr) {
2066+ return _decodeAndCacheFullPrompt (batch, tokensPtr, ctx, nTokens);
2067+ }
2068+
2069+ final decodeStart = reusedPrefix;
2070+
2071+ final maxSeqPos = llama_memory_seq_pos_max (memory, 0 );
2072+ final removeTo = maxSeqPos >= decodeStart ? maxSeqPos + 1 : decodeStart;
2073+ final removedTail = llama_memory_seq_rm (memory, 0 , decodeStart, removeTo);
2074+ if (! removedTail) {
2075+ return _decodeAndCacheFullPrompt (batch, tokensPtr, ctx, nTokens);
2076+ }
2077+
2078+ final suffixTokenCount = nTokens - decodeStart;
2079+ _decodePromptSegment (
2080+ batch,
2081+ tokensPtr,
2082+ ctx,
2083+ startTokenIndex: decodeStart,
2084+ tokenCount: suffixTokenCount,
2085+ );
2086+
2087+ ctx.cachedPromptTokens = _copyPromptTokens (tokensPtr, nTokens);
2088+
2089+ return nTokens;
2090+ }
2091+
/// Helper: Rebuilds the context memory from scratch for the whole prompt.
///
/// Clears the context memory, decodes tokens `[0, nTokens)` read from
/// [tokensPtr], and then records those tokens in `ctx.cachedPromptTokens`.
///
/// [existingCachedTokens] lets the caller hand in an already materialised
/// list holding these same tokens; when it is non-null it is stored as the
/// cache instead of copying [tokensPtr] again.
///
/// Returns [nTokens]. Exceptions thrown while clearing memory or decoding
/// propagate to the caller; in that case `ctx.cachedPromptTokens` is left
/// `null`.
int _decodeAndCacheFullPrompt(
  llama_batch batch,
  Pointer<Int32> tokensPtr,
  _LlamaContextWrapper ctx,
  int nTokens, {
  List<int>? existingCachedTokens,
}) {
  // Invalidate the token cache before the destructive work. If the clear or
  // the decode throws, the context memory no longer matches the previously
  // cached tokens, and a stale list must not survive to be compared against
  // a later prompt.
  ctx.cachedPromptTokens = null;

  _clearContextMemory(ctx.pointer);
  _decodePromptSegment(
    batch,
    tokensPtr,
    ctx,
    startTokenIndex: 0,
    tokenCount: nTokens,
  );

  // Only publish the cache once the memory really holds the full prompt.
  ctx.cachedPromptTokens =
      existingCachedTokens ?? _copyPromptTokens(tokensPtr, nTokens);
  return nTokens;
}
2111+
/// Helper: Copies the first [tokenCount] tokens out of native memory.
///
/// Returns a fixed-length Dart list with the values read from [tokensPtr],
/// or a constant empty list when [tokenCount] is zero or negative. The
/// result is an independent copy, not a view over the native buffer.
List<int> _copyPromptTokens(Pointer<Int32> tokensPtr, int tokenCount) {
  final hasTokens = tokenCount > 0;
  if (!hasTokens) {
    return const <int>[];
  }

  return List<int>.generate(
    tokenCount,
    (index) => tokensPtr[index],
    growable: false,
  );
}
2118+
/// Helper: Decodes one contiguous slice of the prompt into the context.
///
/// Fills [batch] with [tokenCount] tokens read from [tokensPtr] starting at
/// [startTokenIndex]. Each token's position is its absolute index in the
/// prompt (`startTokenIndex + i`), every token is assigned to sequence id 0,
/// and only the last token of the slice has its logits flag set.
///
/// Does nothing when [tokenCount] is zero or negative.
///
/// Throws an [Exception] when `llama_decode` returns a non-zero status.
void _decodePromptSegment(
  llama_batch batch,
  Pointer<Int32> tokensPtr,
  _LlamaContextWrapper ctx, {
  required int startTokenIndex,
  required int tokenCount,
}) {
  if (tokenCount <= 0) {
    return;
  }

  final lastSlot = tokenCount - 1;
  batch.n_tokens = tokenCount;

  for (var slot = 0; slot <= lastSlot; slot++) {
    final promptPosition = startTokenIndex + slot;
    batch.token[slot] = tokensPtr[promptPosition];
    batch.pos[slot] = promptPosition;
    batch.n_seq_id[slot] = 1;
    batch.seq_id[slot][0] = 0;
    // Logits are requested for the final token of the slice only.
    batch.logits[slot] = slot == lastSlot ? 1 : 0;
  }

  final decodeStatus = llama_decode(ctx.pointer, batch);
  if (decodeStatus != 0) {
    throw Exception("Initial decode failed");
  }
}
20252144
2026- return nTokens;
/// Helper: Counts how many leading tokens two prompts have in common.
///
/// Compares [cachedTokens] with the first [newTokenCount] entries read from
/// [newTokens], element by element, and returns the number of positions that
/// match before the first difference. The result never exceeds the shorter
/// of the two lengths.
int _sharedPrefixLength(
  List<int> cachedTokens,
  Pointer<Int32> newTokens,
  int newTokenCount,
) {
  final cachedCount = cachedTokens.length;
  final compareLimit =
      newTokenCount < cachedCount ? newTokenCount : cachedCount;

  var matched = 0;
  for (; matched < compareLimit; matched++) {
    if (cachedTokens[matched] != newTokens[matched]) {
      break;
    }
  }
  return matched;
}
20282159
20292160 /// Helper: Initializes the sampler chain.
@@ -2113,7 +2244,6 @@ class LlamaCppService {
21132244 final accumulatedBytes = < int > [];
21142245
21152246 for (int i = 0 ; i < params.maxTokens; i++ ) {
2116- await Future .delayed (Duration .zero);
21172247 if (cancelToken.value == 1 ) break ;
21182248 if (currentPos >= nCtx) break ;
21192249
@@ -2412,15 +2542,18 @@ class LlamaCppService {
24122542 }
24132543 activeLoras[path] = scale;
24142544 _applyActiveLoras (ctx.pointer, modelAdapters, activeLoras);
2545+ ctx.cachedPromptTokens = null ;
24152546 } else if (op == 'remove' ) {
24162547 if (path == null ) {
24172548 throw Exception ('LoRA path is required for remove operation' );
24182549 }
24192550 activeLoras.remove (path);
24202551 _applyActiveLoras (ctx.pointer, modelAdapters, activeLoras);
2552+ ctx.cachedPromptTokens = null ;
24212553 } else if (op == 'clear' ) {
24222554 activeLoras.clear ();
24232555 _applyActiveLoras (ctx.pointer, modelAdapters, activeLoras);
2556+ ctx.cachedPromptTokens = null ;
24242557 } else {
24252558 throw Exception ('Unknown LoRA operation: $op ' );
24262559 }
@@ -3090,10 +3223,12 @@ class _LlamaModelWrapper {
/// Wraps a native `llama_context` pointer together with per-context state.
class _LlamaContextWrapper {
  /// The native context handle; released by [dispose] via `llama_free`.
  final Pointer<llama_context> pointer;

  /// Reference to the owning model wrapper, if any.
  ///
  /// NOTE(review): judging by the name, this presumably exists only so the
  /// model outlives the context - confirm against the construction sites.
  final _LlamaModelWrapper? _modelKeepAlive;

  /// Tokens of the prompt most recently recorded for this context, or `null`
  /// when nothing is cached.
  List<int>? cachedPromptTokens;

  _LlamaContextWrapper(this.pointer, this._modelKeepAlive);

  /// Drops the cached prompt tokens and frees the native context.
  void dispose() {
    cachedPromptTokens = null;
    // ignore: unused_local_variable
    final _ = _modelKeepAlive;
    llama_free(pointer);
  }
}
0 commit comments