Commit edb1306

feat: add various strategies for handling context window

1 parent 1ebccfe commit edb1306

6 files changed: 178 additions & 10 deletions

packages/react-native-executorch/src/constants/llmDefaults.ts

Lines changed: 4 additions & 1 deletion

@@ -1,4 +1,5 @@
 import { ChatConfig, Message } from '../types/llm';
+import { MessageCountContextStrategy } from '../utils/llms/context_strategy/MessageCountContextStrategy';
 
 /**
  * Default system prompt used to guide the behavior of Large Language Models (LLMs).
@@ -48,5 +49,7 @@ export const DEFAULT_CONTEXT_WINDOW_LENGTH = 5;
 export const DEFAULT_CHAT_CONFIG: ChatConfig = {
   systemPrompt: DEFAULT_SYSTEM_PROMPT,
   initialMessageHistory: DEFAULT_MESSAGE_HISTORY,
-  contextWindowLength: DEFAULT_CONTEXT_WINDOW_LENGTH,
+  contextStrategy: new MessageCountContextStrategy(
+    DEFAULT_CONTEXT_WINDOW_LENGTH
+  ),
 };
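
The default config now plugs in a message-count strategy, but any ContextStrategy implementation can be substituted. A minimal consumer-side sketch of such an override, assuming the package exposes ChatConfig and the strategy classes from its public entry point (the import specifiers below are assumptions, not confirmed exports):

// Hypothetical consumer-side override of the default context strategy.
// The 'react-native-executorch' import path is an assumption; adjust to
// match the package's actual public exports.
import { ChatConfig } from 'react-native-executorch';
import { SlidingWindowContextStrategy } from 'react-native-executorch';

const chatConfig: ChatConfig = {
  systemPrompt: 'You are a helpful translator.',
  initialMessageHistory: [],
  // Swap the default message-count window for token-aware trimming:
  // up to 4096 model tokens, reserving 1000 for the generated response.
  contextStrategy: new SlidingWindowContextStrategy(4096, 1000),
};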

packages/react-native-executorch/src/controllers/LLMController.ts

Lines changed: 21 additions & 7 deletions

@@ -305,15 +305,29 @@ export class LLMController {
   }
 
   public async sendMessage(message: string): Promise<string> {
-    this.messageHistoryCallback([
+    const updatedHistory = [
       ...this._messageHistory,
-      { content: message, role: 'user' },
-    ]);
-
-    const messageHistoryWithPrompt: Message[] = [
-      { content: this.chatConfig.systemPrompt, role: 'system' },
-      ...this._messageHistory.slice(-this.chatConfig.contextWindowLength),
+      { content: message, role: 'user' as const },
     ];
+    this.messageHistoryCallback(updatedHistory);
+
+    const countTokensCallback = (messages: Message[]) => {
+      const rendered = this.applyChatTemplate(
+        messages,
+        this.tokenizerConfig,
+        this.toolsConfig?.tools,
+        // eslint-disable-next-line camelcase
+        { tools_in_user_message: false, add_generation_prompt: true }
+      );
+      return this.nativeModule.getTokenCount(rendered);
+    };
+
+    const messageHistoryWithPrompt =
+      this.chatConfig.contextStrategy.buildContext(
+        this.chatConfig.systemPrompt,
+        updatedHistory,
+        countTokensCallback
+      );
 
     const response = await this.generate(
       messageHistoryWithPrompt,
packages/react-native-executorch/src/types/llm.ts

Lines changed: 22 additions & 2 deletions

@@ -212,13 +212,13 @@ export type LLMTool = Object;
  *
  * @category Types
  * @property {Message[]} initialMessageHistory - An array of `Message` objects that represent the conversation history. This can be used to provide initial context to the model.
- * @property {number} contextWindowLength - The number of messages from the current conversation that the model will use to generate a response. The higher the number, the more context the model will have. Keep in mind that using larger context windows will result in longer inference time and higher memory usage.
  * @property {string} systemPrompt - Often used to tell the model what is its purpose, for example - "Be a helpful translator".
+ * @property {ContextStrategy} contextStrategy - Defines a strategy for managing the conversation context window and message history.
  */
 export interface ChatConfig {
   initialMessageHistory: Message[];
-  contextWindowLength: number;
   systemPrompt: string;
+  contextStrategy: ContextStrategy;
 }
 
 /**
@@ -251,6 +251,26 @@ export interface GenerationConfig {
   batchTimeInterval?: number;
 }
 
+/**
+ * Defines a strategy for managing the conversation context window and message history.
+ *
+ * @category Types
+ */
+export interface ContextStrategy {
+  /**
+   * Constructs the final array of messages to be sent to the model for the current inference step.
+   *
+   * @param systemPrompt - The top-level instructions or persona assigned to the model.
+   * @param history - The complete conversation history up to the current point.
+   * @param getTokenCount - A callback function provided by the LLM controller that calculates the exact number of tokens a specific array of messages will consume once formatted.
+   * @returns The optimized array of messages, ready to be processed by the model.
+   */
+  buildContext(
+    systemPrompt: string,
+    history: Message[],
+    getTokenCount: (messages: Message[]) => number
+  ): Message[];
+}
+
 /**
  * Special tokens used in Large Language Models (LLMs).
  *
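
Because ContextStrategy is a single-method interface, callers can implement custom trimming policies without touching the controller. As an illustration (this class is hypothetical, not part of the commit), a strategy that pins the first user message while keeping a recent tail:

// Hypothetical custom strategy: always keep the first message of the
// history (often the task framing) plus the N most recent messages.
import { ContextStrategy, Message } from './types/llm'; // import path assumed

class PinnedHeadContextStrategy implements ContextStrategy {
  constructor(private readonly recentCount: number = 4) {}

  buildContext(
    systemPrompt: string,
    history: Message[],
    _getTokenCount: (messages: Message[]) => number
  ): Message[] {
    const head = history.slice(0, 1);
    const tail = history.slice(-this.recentCount);
    // Avoid duplicating the head when the history is short enough
    // that the tail already contains it.
    const body = tail[0] === head[0] ? tail : [...head, ...tail];
    return [{ content: systemPrompt, role: 'system' as const }, ...body];
  }
}
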
packages/react-native-executorch/src/utils/llms/context_strategy/MessageCountContextStrategy.ts

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+import { DEFAULT_CONTEXT_WINDOW_LENGTH } from '../../../constants/llmDefaults';
+import { ContextStrategy, Message } from '../../../types/llm';
+
+/**
+ * A simple context strategy that retains a fixed number of the most recent messages.
+ * This strategy trims the conversation history based purely on the message count.
+ *
+ * @category Utils
+ */
+export class MessageCountContextStrategy implements ContextStrategy {
+  /**
+   * Initializes the MessageCountContextStrategy.
+   *
+   * @param {number} windowLength - The maximum number of recent messages to retain in the context. Defaults to {@link DEFAULT_CONTEXT_WINDOW_LENGTH}.
+   */
+  constructor(
+    private readonly windowLength: number = DEFAULT_CONTEXT_WINDOW_LENGTH
+  ) {}
+
+  /**
+   * Builds the context by slicing the history to retain only the most recent `windowLength` messages.
+   *
+   * @param {string} systemPrompt - The top-level instructions for the model.
+   * @param {Message[]} history - The complete conversation history.
+   * @param {(messages: Message[]) => number} _getTokenCount - Unused in this strategy.
+   * @returns {Message[]} The truncated message history with the system prompt at the beginning.
+   */
+  buildContext(
+    systemPrompt: string,
+    history: Message[],
+    _getTokenCount: (messages: Message[]) => number
+  ): Message[] {
+    return [
+      { content: systemPrompt, role: 'system' as const },
+      ...history.slice(-this.windowLength),
+    ];
+  }
+}
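
Behaviorally this matches the old contextWindowLength slicing. A quick usage sketch (messages abbreviated for illustration; the token counter is ignored by this strategy, so a stub suffices):

// Keep only the 2 most recent messages plus the system prompt.
const strategy = new MessageCountContextStrategy(2);

const history = [
  { role: 'user' as const, content: 'Hi' },
  { role: 'assistant' as const, content: 'Hello!' },
  { role: 'user' as const, content: 'Translate "dog" to German.' },
];

const context = strategy.buildContext('Be a helpful translator.', history, () => 0);
// -> [system prompt, 'Hello!', 'Translate "dog" to German.']
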
packages/react-native-executorch/src/utils/llms/context_strategy/NaiveContextStrategy.ts

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+import { ContextStrategy, Message } from '../../../types/llm';
+
+/**
+ * A context strategy that performs no filtering or trimming of the message history.
+ *
+ * This strategy is ideal when the developer wants to manually manage the conversation
+ * context (e.g., using a custom RAG pipeline or external database) and just needs the
+ * system prompt prepended to their pre-computed history.
+ *
+ * @category Utils
+ */
+export class NaiveContextStrategy implements ContextStrategy {
+  /**
+   * Builds the context by prepending the system prompt to the entire unfiltered history.
+   *
+   * @param {string} systemPrompt - The top-level instructions for the model.
+   * @param {Message[]} history - The complete conversation history.
+   * @param {(messages: Message[]) => number} _getTokenCount - Unused in this strategy.
+   * @returns {Message[]} The unedited message history with the system prompt at the beginning.
+   */
+  buildContext(
+    systemPrompt: string,
+    history: Message[],
+    _getTokenCount: (messages: Message[]) => number
+  ): Message[] {
+    return [{ content: systemPrompt, role: 'system' as const }, ...history];
+  }
+}
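
Since no trimming happens here, the only effect is the prepended system message; a one-line usage sketch:

// The history passes through untouched.
const naive = new NaiveContextStrategy();
const context = naive.buildContext(
  'Be concise.',
  [{ role: 'user' as const, content: 'Summarize this document.' }],
  () => 0 // token counter is ignored by this strategy
);
// -> [{ role: 'system', ... }, { role: 'user', ... }]
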
packages/react-native-executorch/src/utils/llms/context_strategy/SlidingWindowContextStrategy.ts

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+import { ContextStrategy, Message } from '../../../types/llm';
+
+/**
+ * An advanced, token-aware context strategy that dynamically trims the message history
+ * to ensure it fits within the model's physical context limits.
+ *
+ * This strategy calculates the exact token count of the formatted prompt. If the prompt
+ * exceeds the allowed token budget (`maxTokens` - `bufferTokens`), it iteratively
+ * removes the oldest messages.
+ *
+ * @category Utils
+ */
+export class SlidingWindowContextStrategy implements ContextStrategy {
+  /**
+   * Initializes the SlidingWindowContextStrategy.
+   * @param {number} maxTokens - The absolute maximum number of tokens the model can process (e.g., 4096).
+   * @param {number} bufferTokens - The number of tokens to keep free for the model's generated response (e.g., 1000).
+   * @param {boolean} allowOrphanedAssistantMessages - Whether to allow orphaned assistant messages when trimming the history. If false, the strategy will ensure that an assistant message is not left without its preceding user message.
+   */
+  constructor(
+    private maxTokens: number,
+    private bufferTokens: number,
+    private allowOrphanedAssistantMessages: boolean = false
+  ) {}
+
+  /**
+   * Builds the context by iteratively evicting the oldest messages until the total
+   * token count is safely within the defined budget.
+   *
+   * @param {string} systemPrompt - The top-level instructions for the model.
+   * @param {Message[]} history - The complete conversation history.
+   * @param {(messages: Message[]) => number} getTokenCount - Callback to calculate the exact token count of the rendered template.
+   * @returns {Message[]} The optimized message history guaranteed to fit the token budget.
+   */
+  buildContext(
+    systemPrompt: string,
+    history: Message[],
+    getTokenCount: (messages: Message[]) => number
+  ): Message[] {
+    let localHistory = [...history];
+    const tokenBudget = this.maxTokens - this.bufferTokens;
+
+    while (localHistory.length > 1) {
+      const candidateContext: Message[] = [
+        { content: systemPrompt, role: 'system' as const },
+        ...localHistory,
+      ];
+
+      if (getTokenCount(candidateContext) <= tokenBudget) {
+        return candidateContext;
+      }
+
+      localHistory.shift();
+
+      if (!this.allowOrphanedAssistantMessages) {
+        // Prevent leaving an orphaned "assistant" response
+        if (localHistory.length > 0 && localHistory[0]?.role === 'assistant') {
+          localHistory.shift();
+        }
+      }
+    }
+
+    return [
+      { content: systemPrompt, role: 'system' as const },
+      ...localHistory,
+    ];
+  }
+}
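
One behavioral note worth flagging: the loop stops once a single message remains, so that final message is returned even if it alone exceeds the budget. A usage sketch with a hypothetical counter that charges ten tokens per message:

// Mock counter: 10 tokens per message (hypothetical, for illustration only).
const sliding = new SlidingWindowContextStrategy(60, 20); // budget = 40 tokens

const mockCount = (messages: { role: string; content: string }[]) =>
  messages.length * 10;

const history = [
  { role: 'user' as const, content: 'q1' },
  { role: 'assistant' as const, content: 'a1' },
  { role: 'user' as const, content: 'q2' },
  { role: 'assistant' as const, content: 'a2' },
  { role: 'user' as const, content: 'q3' },
];

// 1 system + 5 history messages = 60 tokens > 40, so the oldest pair is
// evicted (orphan protection removes 'a1' along with 'q1'), leaving
// 1 system + 3 history = 40 tokens, which fits the budget.
const context = sliding.buildContext('Be brief.', history, mockCount);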
