Commit 4f8bdd0

fix/feat!: LLMs context management (software-mansion#819)
## Description

This PR fixes a few bugs related to LLMs, caused by mixing two approaches: functional (we pass the whole message history each time) and stateful (we keep `pos_` in the runner, representing the current position in the KV cache). This resulted in 3 bugs:

- Broken KV cache for reasoning models: in the runner, we counted the tokens generated for reasoning and included them in the KV cache (`pos_ += num_generated_tokens`), but in subsequent turns the Jinja template removed these reasoning tokens from the message history. As a result, the KV cache was incoherent.
- Duplicated tokens in the KV cache: we were passing the whole message history to the runner (functional approach), but we were also appending all tokens (prompt and generated) to the KV cache (whose position is represented by `pos_`). As a result, tokens were "duplicated" in the KV cache and we were running out of available tokens very quickly (exceeding `context_window_length`).
- Stateful TS functional API: even though our `generate()` method is called functional, it kept internal state in the runner (e.g. `pos_`).

These bugs were fixed by resetting the runner before each generation, which makes it truly functional: old messages are prefilled, and the KV cache can still be used during the generation phase.

Additionally, this PR adds `ContextStrategy` to the `ChatConfig` interface, so it is now possible to define a strategy for managing context, or to use one of the already implemented ones (e.g. naive, message count based, sliding window). This gives us more flexibility, and users can decide what is best for their use case. From now on, `SlidingWindowContextStrategy` is also configured as the default.

### Introduces a breaking change?

- [x] Yes
- [ ] No

These changes will not break anything unless the maximum number of messages was customized: `contextWindowLength` was removed from `ChatConfig` and replaced with `contextStrategy`.

### Type of change

- [x] Bug fix (change which fixes an issue)
- [x] New feature (change which adds functionality)
- [ ] Documentation update (improves or adds clarity to existing documentation)
- [ ] Other (chores, tests, code style improvements etc.)

### Tested on

- [x] iOS
- [x] Android

### Testing instructions

Run the example llm app, open the executorch logs (for example `adb logcat | grep -i "executorch"`) and check that the token counts are properly aligned and that `pos_` is correct. To test different context management strategies, change `contextStrategy` in the llm app and modify the model configuration.

### Screenshots

<!-- Add screenshots here, if applicable -->

### Related issues

software-mansion#776

### Checklist

- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have updated the documentation accordingly
- [x] My changes generate no new warnings

### Additional notes

Position in the KV cache, number of prompt tokens and number of generated tokens for both non-reasoning and reasoning models BEFORE the changes.

LLAMA 3.2 1B SPINQUANT (without reasoning)

| `pos_` | Prompt tokens | Generated tokens |
|------------------------|---------------|------------------|
| 0 | 335 | 269 |
| 604 = 269 + 335 | 872 | 372 |
| 1848 = 604 + 872 + 372 | 1513 | CRASH |

QWEN 3.0 0.6B QUANTIZED (with reasoning)

| `pos_` | Prompt tokens | Generated tokens |
|------------------------|---------------|------------------|
| 0 | 309 | 457 |
| 766 = 309 + 457 | 617 (< 766!) | 192 |
| 1575 = 766 + 617 + 192 | 925 (< 1575!) | CRASH |
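The `pos_` arithmetic in the tables above can be reproduced with a small model. This is only a sketch of the bookkeeping, not the runner's code: the `Turn` type and both function names are invented for illustration, and the "with reset" variant assumes that each turn starts from position 0, prefills the full prompt and then generates.

```typescript
// Token counts for one conversation turn (values below come from the tables).
interface Turn {
  promptTokens: number;
  generatedTokens: number;
}

// Old behaviour (hypothetical model): pos_ is never reset, so each turn adds
// the full re-sent prompt plus the new generation on top of the cached tokens.
function endPositionsWithoutReset(turns: Turn[]): number[] {
  const ends: number[] = [];
  let pos = 0;
  for (const turn of turns) {
    pos += turn.promptTokens + turn.generatedTokens;
    ends.push(pos);
  }
  return ends;
}

// Assumed new behaviour: the position restarts at 0 before every generation,
// so a turn ends at prompt + generated and nothing is counted twice.
function endPositionsWithReset(turns: Turn[]): number[] {
  return turns.map((turn) => turn.promptTokens + turn.generatedTokens);
}

const llamaTurns: Turn[] = [
  { promptTokens: 335, generatedTokens: 269 },
  { promptTokens: 872, generatedTokens: 372 },
];
const qwenTurns: Turn[] = [
  { promptTokens: 309, generatedTokens: 457 },
  { promptTokens: 617, generatedTokens: 192 },
];

console.log(endPositionsWithoutReset(llamaTurns)); // [ 604, 1848 ]
console.log(endPositionsWithReset(llamaTurns)); // [ 604, 1244 ]
console.log(endPositionsWithoutReset(qwenTurns)); // [ 766, 1575 ]
console.log(endPositionsWithReset(qwenTurns)); // [ 766, 809 ]
```

Without a reset the position grows roughly quadratically with the number of turns, because every earlier message is counted again each time the history is re-sent. In the Qwen rows the second prompt (617 tokens) is shorter than the position it is appended at (766), which matches the reasoning tokens having been stripped from the history.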
1 parent 31eca42 commit 4f8bdd0

271 files changed

Lines changed: 1474 additions & 915 deletions


.cspell-wordlist.txt

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ microcontrollers
 notimestamps
 seqs
 smollm
+llms
 qwen
 XNNPACK
 EFFICIENTNET

docs/docs/03-hooks/01-natural-language-processing/useLLM.md

Lines changed: 1 addition & 1 deletion
@@ -192,7 +192,7 @@ To configure model (i.e. change system prompt, load initial conversation history
 
 - [`initialMessageHistory`](../../06-api-reference/interfaces/ChatConfig.md#initialmessagehistory) - Object that represent the conversation history. This can be used to provide initial context to the model.
 
-- [`contextWindowLength`](../../06-api-reference/interfaces/ChatConfig.md#contextwindowlength) - The number of messages from the current conversation that the model will use to generate a response. Keep in mind that using larger context windows will result in longer inference time and higher memory usage.
+- [`contextStrategy`](../../06-api-reference/interfaces/ChatConfig.md#contextstrategy) - Object implementing [`ContextStrategy`](../../06-api-reference/interfaces/ContextStrategy.md) interface used to manage conversation context, including trimming history if necessary. Custom strategies can be implemented or one of the built-in options can be used (e.g. [`NoopContextStrategy`](../../06-api-reference/classes/NoopContextStrategy.md), [`MessageCountContextStrategy`](../../06-api-reference/classes/MessageCountContextStrategy.md) or the default [`SlidingWindowContextStrategy`](../../06-api-reference/classes/SlidingWindowContextStrategy.md)).
 
 - [`toolsConfig`](../../06-api-reference/interfaces/LLMConfig.md#toolsconfig) - Object configuring options for enabling and managing tool use. **It will only have effect if your model's chat template support it**. Contains following properties:
 - [`tools`](../../06-api-reference/interfaces/ToolsConfig.md#tools) - List of objects defining tools.
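The idea behind the `contextStrategy` entry in this hunk, trimming history before it reaches the model, can be illustrated with a rough sketch. The members of the real `ContextStrategy` interface are not shown in this diff, so everything below is an assumption made for illustration: the `ContextStrategySketch` interface, its `buildContext` method, the `ChatMessage` shape, both class names and the word-count token estimate are hypothetical and are not the library's API.

```typescript
// Hypothetical message shape, for illustration only.
type Role = 'system' | 'user' | 'assistant';

interface ChatMessage {
  role: Role;
  content: string;
}

// Hypothetical strategy shape: take the full history, return what to send.
interface ContextStrategySketch {
  buildContext(history: ChatMessage[]): ChatMessage[];
}

// Message-count idea: keep only the newest N messages.
class LastNMessagesSketch implements ContextStrategySketch {
  private readonly maxMessages: number;

  constructor(maxMessages: number) {
    this.maxMessages = maxMessages;
  }

  buildContext(history: ChatMessage[]): ChatMessage[] {
    // slice(-0) would return the whole array, so handle non-positive limits.
    return this.maxMessages > 0 ? history.slice(-this.maxMessages) : [];
  }
}

// Sliding-window idea: keep the newest messages that fit a token budget.
class TokenBudgetSketch implements ContextStrategySketch {
  private readonly maxTokens: number;
  private readonly countTokens: (text: string) => number;

  constructor(maxTokens: number, countTokens: (text: string) => number) {
    this.maxTokens = maxTokens;
    this.countTokens = countTokens;
  }

  buildContext(history: ChatMessage[]): ChatMessage[] {
    const kept: ChatMessage[] = [];
    let used = 0;
    // Walk from newest to oldest and stop at the first message that overflows.
    for (const message of [...history].reverse()) {
      const cost = this.countTokens(message.content);
      if (used + cost > this.maxTokens) break;
      kept.unshift(message);
      used += cost;
    }
    return kept;
  }
}

// Stand-in token counter (whitespace-separated words); a real tokenizer differs.
const wordCount = (text: string): number =>
  text.split(/\s+/).filter(Boolean).length;

const history: ChatMessage[] = [
  { role: 'user', content: 'a b c' },
  { role: 'assistant', content: 'd e' },
  { role: 'user', content: 'f g h i' },
  { role: 'assistant', content: 'j' },
];

console.log(new LastNMessagesSketch(3).buildContext(history).length); // 3
console.log(new TokenBudgetSketch(6, wordCount).buildContext(history).length); // 2
```

A count-based limit is simple but ignores message length, while a token budget tracks what actually fills the context window. Which trade-off the built-in strategies make is documented on their own API reference pages.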

docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ To configure model (i.e. change system prompt, load initial conversation history
 
 - [`initialMessageHistory`](../../06-api-reference/interfaces/ChatConfig.md#initialmessagehistory) - Object that represent the conversation history. This can be used to provide initial context to the model.
 
-- [`contextWindowLength`](../../06-api-reference/interfaces/ChatConfig.md#contextwindowlength) - The number of messages from the current conversation that the model will use to generate a response. Keep in mind that using larger context windows will result in longer inference time and higher memory usage.
+- [`contextStrategy`](../../06-api-reference/interfaces/ChatConfig.md#contextstrategy) - Object implementing [`ContextStrategy`](../../06-api-reference/interfaces/ContextStrategy.md) interface used to manage conversation context, including trimming history if necessary. Custom strategies can be implemented or one of the built-in options can be used (e.g. [`NoopContextStrategy`](../../06-api-reference/classes/NoopContextStrategy.md), [`MessageCountContextStrategy`](../../06-api-reference/classes/MessageCountContextStrategy.md) or the default [`SlidingWindowContextStrategy`](../../06-api-reference/classes/SlidingWindowContextStrategy.md)).
 
 - [`toolsConfig`](../../06-api-reference/interfaces/ToolsConfig.md) - Object configuring options for enabling and managing tool use. **It will only have effect if your model's chat template support it**. Contains following properties:
 - [`tools`](../../06-api-reference/interfaces/ToolsConfig.md#tools) - List of objects defining tools.

docs/docs/06-api-reference/classes/ClassificationModule.md

Lines changed: 7 additions & 7 deletions
@@ -1,6 +1,6 @@
 # Class: ClassificationModule
 
-Defined in: [packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts:12](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts#L12)
+Defined in: [packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts:13](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts#L13)
 
 Module for image classification tasks.
 
@@ -28,7 +28,7 @@ Module for image classification tasks.
 
 > **nativeModule**: `any` = `null`
 
-Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:8](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/BaseModule.ts#L8)
+Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:8](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/BaseModule.ts#L8)
 
 Native module instance
 
@@ -42,7 +42,7 @@ Native module instance
 
 > **delete**(): `void`
 
-Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:41](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/BaseModule.ts#L41)
+Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:41](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/BaseModule.ts#L41)
 
 Unloads the model from memory.
 
@@ -60,7 +60,7 @@ Unloads the model from memory.
 
 > **forward**(`imageSource`): `Promise`\<\{\[`category`: `string`\]: `number`; \}\>
 
-Defined in: [packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts:43](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts#L43)
+Defined in: [packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts:51](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts#L51)
 
 Executes the model's forward pass, where `imageSource` can be a fetchable resource or a Base64-encoded string.
 
@@ -84,7 +84,7 @@ The classification result.
 
 > `protected` **forwardET**(`inputTensor`): `Promise`\<[`TensorPtr`](../interfaces/TensorPtr.md)[]\>
 
-Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:23](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/BaseModule.ts#L23)
+Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:23](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/BaseModule.ts#L23)
 
 Runs the model's forward method with the given input tensors.
 It returns the output tensors that mimic the structure of output from ExecuTorch.
@@ -113,7 +113,7 @@ Array of output tensors.
 
 > **getInputShape**(`methodName`, `index`): `Promise`\<`number`[]\>
 
-Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:34](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/BaseModule.ts#L34)
+Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:34](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/BaseModule.ts#L34)
 
 Gets the input shape for a given method and index.
 
@@ -147,7 +147,7 @@ The input shape as an array of numbers.
 
 > **load**(`model`, `onDownloadProgressCallback`): `Promise`\<`void`\>
 
-Defined in: [packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts:20](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts#L20)
+Defined in: [packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts:21](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts#L21)
 
 Loads the model, where `modelSource` is a string that specifies the location of the model binary.
 To track the download progress, supply a callback function `onDownloadProgressCallback`.

docs/docs/06-api-reference/classes/ExecutorchModule.md

Lines changed: 7 additions & 7 deletions
@@ -1,6 +1,6 @@
 # Class: ExecutorchModule
 
-Defined in: [packages/react-native-executorch/src/modules/general/ExecutorchModule.ts:13](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts#L13)
+Defined in: [packages/react-native-executorch/src/modules/general/ExecutorchModule.ts:14](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts#L14)
 
 General module for executing custom Executorch models.
 
@@ -28,7 +28,7 @@ General module for executing custom Executorch models.
 
 > **nativeModule**: `any` = `null`
 
-Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:8](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/BaseModule.ts#L8)
+Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:8](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/BaseModule.ts#L8)
 
 Native module instance
 
@@ -42,7 +42,7 @@ Native module instance
 
 > **delete**(): `void`
 
-Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:41](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/BaseModule.ts#L41)
+Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:41](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/BaseModule.ts#L41)
 
 Unloads the model from memory.
 
@@ -60,7 +60,7 @@ Unloads the model from memory.
 
 > **forward**(`inputTensor`): `Promise`\<[`TensorPtr`](../interfaces/TensorPtr.md)[]\>
 
-Defined in: [packages/react-native-executorch/src/modules/general/ExecutorchModule.ts:45](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts#L45)
+Defined in: [packages/react-native-executorch/src/modules/general/ExecutorchModule.ts:51](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts#L51)
 
 Executes the model's forward pass, where input is an array of `TensorPtr` objects.
 If the inference is successful, an array of tensor pointers is returned.
@@ -85,7 +85,7 @@ An array of output tensor pointers.
 
 > `protected` **forwardET**(`inputTensor`): `Promise`\<[`TensorPtr`](../interfaces/TensorPtr.md)[]\>
 
-Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:23](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/BaseModule.ts#L23)
+Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:23](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/BaseModule.ts#L23)
 
 Runs the model's forward method with the given input tensors.
 It returns the output tensors that mimic the structure of output from ExecuTorch.
@@ -114,7 +114,7 @@ Array of output tensors.
 
 > **getInputShape**(`methodName`, `index`): `Promise`\<`number`[]\>
 
-Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:34](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/BaseModule.ts#L34)
+Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:34](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/BaseModule.ts#L34)
 
 Gets the input shape for a given method and index.
 
@@ -148,7 +148,7 @@ The input shape as an array of numbers.
 
 > **load**(`modelSource`, `onDownloadProgressCallback`): `Promise`\<`void`\>
 
-Defined in: [packages/react-native-executorch/src/modules/general/ExecutorchModule.ts:21](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts#L21)
+Defined in: [packages/react-native-executorch/src/modules/general/ExecutorchModule.ts:22](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts#L22)
 
 Loads the model, where `modelSource` is a string, number, or object that specifies the location of the model binary.
 Optionally accepts a download progress callback.

docs/docs/06-api-reference/classes/ImageEmbeddingsModule.md

Lines changed: 7 additions & 7 deletions
@@ -1,6 +1,6 @@
 # Class: ImageEmbeddingsModule
 
-Defined in: [packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts:12](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts#L12)
+Defined in: [packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts:13](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts#L13)
 
 Module for generating image embeddings from input images.
 
@@ -28,7 +28,7 @@ Module for generating image embeddings from input images.
 
 > **nativeModule**: `any` = `null`
 
-Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:8](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/BaseModule.ts#L8)
+Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:8](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/BaseModule.ts#L8)
 
 Native module instance
 
@@ -42,7 +42,7 @@ Native module instance
 
 > **delete**(): `void`
 
-Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:41](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/BaseModule.ts#L41)
+Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:41](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/BaseModule.ts#L41)
 
 Unloads the model from memory.
 
@@ -60,7 +60,7 @@ Unloads the model from memory.
 
 > **forward**(`imageSource`): `Promise`\<`Float32Array`\<`ArrayBufferLike`\>\>
 
-Defined in: [packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts:42](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts#L42)
+Defined in: [packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts:50](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts#L50)
 
 Executes the model's forward pass. Returns an embedding array for a given sentence.
 
@@ -84,7 +84,7 @@ A Float32Array containing the image embeddings.
 
 > `protected` **forwardET**(`inputTensor`): `Promise`\<[`TensorPtr`](../interfaces/TensorPtr.md)[]\>
 
-Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:23](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/BaseModule.ts#L23)
+Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:23](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/BaseModule.ts#L23)
 
 Runs the model's forward method with the given input tensors.
 It returns the output tensors that mimic the structure of output from ExecuTorch.
@@ -113,7 +113,7 @@ Array of output tensors.
 
 > **getInputShape**(`methodName`, `index`): `Promise`\<`number`[]\>
 
-Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:34](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/BaseModule.ts#L34)
+Defined in: [packages/react-native-executorch/src/modules/BaseModule.ts:34](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/BaseModule.ts#L34)
 
 Gets the input shape for a given method and index.
 
@@ -147,7 +147,7 @@ The input shape as an array of numbers.
 
 > **load**(`model`, `onDownloadProgressCallback`): `Promise`\<`void`\>
 
-Defined in: [packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts:19](https://github.com/software-mansion/react-native-executorch/blob/326d6344894d75625c600d5988666e215a32d466/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts#L19)
+Defined in: [packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts:20](https://github.com/software-mansion/react-native-executorch/blob/a6b2b6f4f1622166e3517338d42680655383a3be/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts#L20)
 
 Loads the model, where `modelSource` is a string that specifies the location of the model binary.
