Skip to content

Commit c7c0ec0

Browse files
authored
Merge pull request #61 from leehack/perf/native-inference-optimization
perf(native): improve inference hot paths and add parity tooling
2 parents b252507 + 9505d29 commit c7c0ec0

22 files changed

Lines changed: 1561 additions & 58 deletions

.github/workflows/ci.yml

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,44 @@ jobs:
140140
env:
141141
GGML_METAL_DEVICES: "0"
142142
run: dart test -p vm -j 1 --exclude-tags local-only
143+
144+
native-prompt-reuse-parity:
145+
name: Native Prompt Reuse Parity
146+
runs-on: ubuntu-latest
147+
timeout-minutes: 35
148+
steps:
149+
- uses: actions/checkout@v4
150+
151+
- name: Setup Flutter
152+
uses: subosito/flutter-action@v2
153+
with:
154+
channel: 'stable'
155+
cache: true
156+
157+
- name: Install dependencies
158+
run: flutter pub get
159+
160+
- name: Download parity model
161+
run: |
162+
mkdir -p "$RUNNER_TEMP/models"
163+
curl --retry 5 --retry-all-errors --retry-delay 3 --location \
164+
"https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/qwen2.5-0.5b-instruct-q4_k_m.gguf?download=true" \
165+
--output "$RUNNER_TEMP/models/qwen2.5-0.5b-instruct-q4_k_m.gguf"
166+
167+
- name: Run native prompt reuse parity
168+
run: |
169+
dart run tool/testing/native_prompt_reuse_parity.dart \
170+
--model "$RUNNER_TEMP/models/qwen2.5-0.5b-instruct-q4_k_m.gguf" \
171+
--prompt-file "tool/testing/prompts/native_prompt_reuse_ci_prompts.txt" \
172+
--max-prompts 4 \
173+
--runs 1 \
174+
--max-tokens 64 \
175+
--gpu-layers 0 \
176+
--threads 1 \
177+
--threads-batch 1 \
178+
--temp 0 \
179+
--top-k 1 \
180+
--top-p 1.0 \
181+
--min-p 0.0 \
182+
--repeat-penalty 1.0 \
183+
--fail-on-mismatch

CHANGELOG.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,23 @@
1+
## 0.6.2
2+
3+
* **Native inference performance improvements**:
4+
* Reduced request overhead by caching model metadata and skipping
5+
unnecessary prompt token counting in `create(...)`.
6+
* Improved native stream throughput with worker-side token chunk batching
7+
and configurable thresholds (`streamBatchTokenThreshold`,
8+
`streamBatchByteThreshold`).
9+
* Added prompt-prefix reuse for native text generation
10+
(`reusePromptPrefix`, enabled by default) with conservative full-replay
11+
fallback to preserve deterministic parity.
12+
* Optimized `ChatSession` context trimming using bounded turn-offset
13+
search to avoid repeated linear recount loops on long histories.
14+
* **Benchmarking and parity tooling**:
15+
* Added `tool/testing/native_inference_benchmark.dart` for TTFT,
16+
throughput, and latency measurement with tunable generation settings.
17+
* Added `tool/testing/native_prompt_reuse_parity.dart` and curated prompt
18+
sets for deterministic prompt-reuse parity validation.
19+
* Added CI prompt-reuse parity checks to catch native reuse regressions.
20+
121
## 0.6.1
222

323
* **Publishing compatibility fix**:

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
```yaml
3737
dependencies:
38-
llamadart: ^0.6.1
38+
llamadart: ^0.6.2
3939
```
4040
4141
### 2. Run with defaults

example/chat_app/test/mocks.dart

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ class MockLlamaEngine extends LlamaEngine {
141141
String? customTemplate,
142142
String? sourceLangCode,
143143
String? targetLangCode,
144+
bool includeTokenCount = true,
144145
Map<String, dynamic>? chatTemplateKwargs,
145146
DateTime? templateNow,
146147
}) async {

lib/src/backends/llama_cpp/llama_cpp_service.dart

Lines changed: 153 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1655,9 +1655,16 @@ class LlamaCppService {
16551655
final model = _models[modelHandle]!;
16561656
final modelParams = _contextParams[contextHandle]!;
16571657
final vocab = llama_model_get_vocab(model.pointer);
1658+
final hasMediaParts =
1659+
parts?.any((p) => p is LlamaImageContent || p is LlamaAudioContent) ??
1660+
false;
16581661

16591662
// 1. Reset Context
1660-
ctx = _resetContext(contextHandle, ctx);
1663+
ctx = _resetContext(
1664+
contextHandle,
1665+
ctx,
1666+
clearMemory: hasMediaParts || !params.reusePromptPrefix,
1667+
);
16611668

16621669
// 2. Prepare Resources
16631670
final nCtx = llama_n_ctx(ctx.pointer);
@@ -1689,6 +1696,7 @@ class LlamaCppService {
16891696
tokensPtr,
16901697
nCtx,
16911698
modelParams,
1699+
allowTextPromptReuse: !hasMediaParts && params.reusePromptPrefix,
16921700
);
16931701

16941702
// 4. Initialize and Run Sampler Loop
@@ -1739,18 +1747,27 @@ class LlamaCppService {
17391747
/// Helper: Prepares the context state for a new generation request.
///
/// Calls `llama_synchronize` on the native context first. When [clearMemory]
/// is true the context memory is cleared and the cached prompt tokens on
/// [ctx] are dropped; otherwise both are left untouched so a later prompt can
/// reuse them. [ctx] is then stored in `_contexts` under [contextHandle] and
/// returned.
_LlamaContextWrapper _resetContext(
  int contextHandle,
  _LlamaContextWrapper ctx, {
  required bool clearMemory,
}) {
  final nativeContext = ctx.pointer;
  llama_synchronize(nativeContext);

  if (clearMemory) {
    _clearContextMemory(nativeContext);
    ctx.cachedPromptTokens = null;
  }

  return _contexts[contextHandle] = ctx;
}
1763+
1764+
/// Clears the memory object attached to [contextPointer].
///
/// Throws an [Exception] when `llama_get_memory` returns `nullptr`.
void _clearContextMemory(Pointer<llama_context> contextPointer) {
  final contextMemory = llama_get_memory(contextPointer);
  if (contextMemory != nullptr) {
    llama_memory_clear(contextMemory, true);
    return;
  }

  throw Exception("Failed to reset context memory");
}
17551772

17561773
/// Helper: Ingests the prompt (text or multimodal) and returns initial token count.
@@ -1764,8 +1781,9 @@ class LlamaCppService {
17641781
List<LlamaContentPart>? parts,
17651782
Pointer<Int32> tokensPtr,
17661783
int nCtx,
1767-
llama_context_params modelParams,
1768-
) {
1784+
llama_context_params modelParams, {
1785+
required bool allowTextPromptReuse,
1786+
}) {
17691787
final mediaParts =
17701788
parts
17711789
?.where((p) => p is LlamaImageContent || p is LlamaAudioContent)
@@ -1784,7 +1802,15 @@ class LlamaCppService {
17841802
modelParams,
17851803
);
17861804
} else {
1787-
return _ingestTextPrompt(batch, vocab, prompt, tokensPtr, nCtx, ctx);
1805+
return _ingestTextPrompt(
1806+
batch,
1807+
vocab,
1808+
prompt,
1809+
tokensPtr,
1810+
nCtx,
1811+
ctx,
1812+
allowPromptReuse: allowTextPromptReuse,
1813+
);
17881814
}
17891815
}
17901816

@@ -1902,6 +1928,7 @@ class LlamaCppService {
19021928
malloc.free(bitmaps);
19031929
_mtmdInputChunksFree(chunks);
19041930
}
1931+
ctx.cachedPromptTokens = null;
19051932
return initialTokens;
19061933
}
19071934

@@ -1991,8 +2018,9 @@ class LlamaCppService {
19912018
String prompt,
19922019
Pointer<Int32> tokensPtr,
19932020
int nCtx,
1994-
_LlamaContextWrapper ctx,
1995-
) {
2021+
_LlamaContextWrapper ctx, {
2022+
required bool allowPromptReuse,
2023+
}) {
19962024
final promptPtr = prompt.toNativeUtf8();
19972025
final shouldAddSpecial = !_promptStartsWithBosToken(vocab, prompt);
19982026
final nTokens = llama_tokenize(
@@ -2010,20 +2038,123 @@ class LlamaCppService {
20102038
throw Exception("Tokenization failed or prompt too long");
20112039
}
20122040

2013-
batch.n_tokens = nTokens;
2014-
for (int i = 0; i < nTokens; i++) {
2015-
batch.token[i] = tokensPtr[i];
2016-
batch.pos[i] = i;
2041+
if (!allowPromptReuse || nTokens == 0) {
2042+
return _decodeAndCacheFullPrompt(batch, tokensPtr, ctx, nTokens);
2043+
}
2044+
2045+
final cachedTokens = ctx.cachedPromptTokens;
2046+
if (cachedTokens == null || cachedTokens.isEmpty) {
2047+
return _decodeAndCacheFullPrompt(batch, tokensPtr, ctx, nTokens);
2048+
}
2049+
2050+
final reusedPrefix = _sharedPrefixLength(cachedTokens, tokensPtr, nTokens);
2051+
2052+
if (reusedPrefix <= 0 || reusedPrefix >= nTokens) {
2053+
final canReuseCachedCopy =
2054+
reusedPrefix == nTokens && cachedTokens.length == nTokens;
2055+
return _decodeAndCacheFullPrompt(
2056+
batch,
2057+
tokensPtr,
2058+
ctx,
2059+
nTokens,
2060+
existingCachedTokens: canReuseCachedCopy ? cachedTokens : null,
2061+
);
2062+
}
2063+
2064+
final memory = llama_get_memory(ctx.pointer);
2065+
if (memory == nullptr) {
2066+
return _decodeAndCacheFullPrompt(batch, tokensPtr, ctx, nTokens);
2067+
}
2068+
2069+
final decodeStart = reusedPrefix;
2070+
2071+
final maxSeqPos = llama_memory_seq_pos_max(memory, 0);
2072+
final removeTo = maxSeqPos >= decodeStart ? maxSeqPos + 1 : decodeStart;
2073+
final removedTail = llama_memory_seq_rm(memory, 0, decodeStart, removeTo);
2074+
if (!removedTail) {
2075+
return _decodeAndCacheFullPrompt(batch, tokensPtr, ctx, nTokens);
2076+
}
2077+
2078+
final suffixTokenCount = nTokens - decodeStart;
2079+
_decodePromptSegment(
2080+
batch,
2081+
tokensPtr,
2082+
ctx,
2083+
startTokenIndex: decodeStart,
2084+
tokenCount: suffixTokenCount,
2085+
);
2086+
2087+
ctx.cachedPromptTokens = _copyPromptTokens(tokensPtr, nTokens);
2088+
2089+
return nTokens;
2090+
}
2091+
2092+
/// Replays the whole prompt from position 0 and records it as the cached
/// prompt for [ctx].
///
/// Clears the context memory, decodes the first [nTokens] tokens of
/// [tokensPtr], and then stores them in `ctx.cachedPromptTokens`. When
/// [existingCachedTokens] is provided it is stored instead of taking a fresh
/// copy of the native buffer. Returns [nTokens].
///
/// Any exception thrown while clearing memory or decoding propagates to the
/// caller; in the decode case the cached prompt is left as `null`.
int _decodeAndCacheFullPrompt(
  llama_batch batch,
  Pointer<Int32> tokensPtr,
  _LlamaContextWrapper ctx,
  int nTokens, {
  List<int>? existingCachedTokens,
}) {
  _clearContextMemory(ctx.pointer);

  // The memory was just cleared, so any previously cached prompt no longer
  // describes it. Invalidate before decoding: if the decode throws, a later
  // prefix-reuse attempt must not match against a stale token list.
  ctx.cachedPromptTokens = null;

  _decodePromptSegment(
    batch,
    tokensPtr,
    ctx,
    startTokenIndex: 0,
    tokenCount: nTokens,
  );

  ctx.cachedPromptTokens =
      existingCachedTokens ?? _copyPromptTokens(tokensPtr, nTokens);
  return nTokens;
}
2111+
2112+
/// Copies the first [tokenCount] tokens from [tokensPtr] into a fixed-length
/// Dart list.
///
/// Returns a constant empty list when [tokenCount] is zero or negative, in
/// which case [tokensPtr] is never read.
List<int> _copyPromptTokens(Pointer<Int32> tokensPtr, int tokenCount) {
  if (tokenCount > 0) {
    final nativeView = tokensPtr.asTypedList(tokenCount);
    return List<int>.from(nativeView, growable: false);
  }
  return const <int>[];
}
2118+
2119+
/// Decodes [tokenCount] prompt tokens starting at [startTokenIndex].
///
/// Fills [batch] from [tokensPtr] so that each token keeps its absolute prompt
/// index as its position, assigns every token to sequence 0, and sets the
/// logits flag only on the last token of the segment. Does nothing when
/// [tokenCount] is zero or negative.
///
/// Throws an [Exception] when `llama_decode` returns a non-zero status.
void _decodePromptSegment(
  llama_batch batch,
  Pointer<Int32> tokensPtr,
  _LlamaContextWrapper ctx, {
  required int startTokenIndex,
  required int tokenCount,
}) {
  if (tokenCount <= 0) return;

  final lastSlot = tokenCount - 1;
  batch.n_tokens = tokenCount;
  for (var slot = 0; slot < tokenCount; slot++) {
    // Batch slots are zero-based; positions stay absolute within the prompt.
    final position = startTokenIndex + slot;
    batch.token[slot] = tokensPtr[position];
    batch.pos[slot] = position;
    batch.n_seq_id[slot] = 1;
    batch.seq_id[slot][0] = 0;
    batch.logits[slot] = slot == lastSlot ? 1 : 0;
  }

  final status = llama_decode(ctx.pointer, batch);
  if (status != 0) {
    throw Exception("Initial decode failed");
  }
}
20252144

2026-
return nTokens;
2145+
/// Returns how many leading tokens [cachedTokens] and [newTokens] share.
///
/// Only the first [newTokenCount] entries of [newTokens] are considered, and
/// the comparison never runs past the end of [cachedTokens]. Returns 0 when
/// either side is empty.
int _sharedPrefixLength(
  List<int> cachedTokens,
  Pointer<Int32> newTokens,
  int newTokenCount,
) {
  final cachedCount = cachedTokens.length;
  final limit = newTokenCount < cachedCount ? newTokenCount : cachedCount;

  var matched = 0;
  for (; matched < limit; matched++) {
    if (cachedTokens[matched] != newTokens[matched]) {
      break;
    }
  }
  return matched;
}
20282159

20292160
/// Helper: Initializes the sampler chain.
@@ -2113,7 +2244,6 @@ class LlamaCppService {
21132244
final accumulatedBytes = <int>[];
21142245

21152246
for (int i = 0; i < params.maxTokens; i++) {
2116-
await Future.delayed(Duration.zero);
21172247
if (cancelToken.value == 1) break;
21182248
if (currentPos >= nCtx) break;
21192249

@@ -2412,15 +2542,18 @@ class LlamaCppService {
24122542
}
24132543
activeLoras[path] = scale;
24142544
_applyActiveLoras(ctx.pointer, modelAdapters, activeLoras);
2545+
ctx.cachedPromptTokens = null;
24152546
} else if (op == 'remove') {
24162547
if (path == null) {
24172548
throw Exception('LoRA path is required for remove operation');
24182549
}
24192550
activeLoras.remove(path);
24202551
_applyActiveLoras(ctx.pointer, modelAdapters, activeLoras);
2552+
ctx.cachedPromptTokens = null;
24212553
} else if (op == 'clear') {
24222554
activeLoras.clear();
24232555
_applyActiveLoras(ctx.pointer, modelAdapters, activeLoras);
2556+
ctx.cachedPromptTokens = null;
24242557
} else {
24252558
throw Exception('Unknown LoRA operation: $op');
24262559
}
@@ -3090,10 +3223,12 @@ class _LlamaModelWrapper {
30903223
/// Pairs a native `llama_context` pointer with the per-context state kept on
/// the Dart side.
class _LlamaContextWrapper {
  /// Native context pointer; passed to `llama_free` by [dispose].
  final Pointer<llama_context> pointer;

  /// Reference to the model wrapper held alongside this context.
  final _LlamaModelWrapper? _modelKeepAlive;

  /// Prompt tokens most recently recorded for this context, or `null` when no
  /// prompt is cached.
  List<int>? cachedPromptTokens;

  _LlamaContextWrapper(this.pointer, this._modelKeepAlive);

  /// Drops the cached prompt tokens and frees the native context.
  void dispose() {
    // ignore: unused_local_variable
    final _ = _modelKeepAlive;
    cachedPromptTokens = null;
    llama_free(pointer);
  }
}

0 commit comments

Comments
 (0)