fix/feat!: LLMs context management #819
```diff
@@ -21,6 +21,7 @@ microcontrollers
 notimestamps
 seqs
 smollm
+llms
 qwen
 XNNPACK
 EFFICIENTNET
```
```diff
@@ -368,4 +368,17 @@ int32_t Runner::resolve_max_new_tokens(int32_t num_prompt_tokens,
   return std::max(0, result);
 }
 
+int32_t Runner::count_text_tokens(const std::string &text) const {
+  auto encodeResult =
+      tokenizer_->encode(text, numOfAddedBoSTokens, numOfAddedEoSTokens);
+
+  if (!encodeResult.ok()) {
+    throw rnexecutorch::RnExecutorchError(
+        rnexecutorch::RnExecutorchErrorCode::TokenizerError,
+        "Encoding failed during token count check.");
+  }
+
+  return static_cast<int32_t>(encodeResult.get().size());
+}
+
 } // namespace example
```

**Member:** I'm wondering if calling the encoder specifically to get only the size of the encoded text is the most efficient way to get the size. Can we compute this during the encoding phase?

**Contributor (author):** If I understand this correctly, that's not possible, for two reasons:

**Contributor (author):** Alternatively, maybe we could implement these context strategies in the C++ code, but as a result we would give the user no flexibility (?). Also:

**Collaborator:** We don't need to bother ourselves with computational complexity here; this is just tokenizer encoding, which is very cheap and performed only once for a given prompt, so it only impacts Time To First Token.
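The TypeScript strategies below all implement a shared `ContextStrategy` interface imported from `'../../../types/llm'`. Its definition is not part of this diff; the following is a plausible reconstruction inferred from the implementations, not the confirmed source:

```ts
// Sketch of the types in '../../../types/llm', reconstructed from usage in the
// strategy files below. Treat the exact shapes as assumptions.
export type MessageRole = 'system' | 'user' | 'assistant';

export interface Message {
  role: MessageRole;
  content: string;
}

export interface ContextStrategy {
  // Returns the message list to feed the model, system prompt first.
  buildContext(
    systemPrompt: string,
    history: Message[],
    getTokenCount: (messages: Message[]) => number
  ): Message[];
}
```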
MessageCountContextStrategy.ts (new file):

```ts
import { DEFAULT_CONTEXT_WINDOW_LENGTH } from '../../../constants/llmDefaults';
import { ContextStrategy, Message } from '../../../types/llm';

/**
 * A simple context strategy that retains a fixed number of the most recent messages.
 * This strategy trims the conversation history based purely on the message count.
 *
 * @category Utils
 */
export class MessageCountContextStrategy implements ContextStrategy {
  /**
   * Initializes the MessageCountContextStrategy.
   *
   * @param {number} windowLength - The maximum number of recent messages to retain in the context. Defaults to {@link DEFAULT_CONTEXT_WINDOW_LENGTH}.
   */
  constructor(
    private readonly windowLength: number = DEFAULT_CONTEXT_WINDOW_LENGTH
  ) {}

  /**
   * Builds the context by slicing the history to retain only the most recent `windowLength` messages.
   *
   * @param {string} systemPrompt - The top-level instructions for the model.
   * @param {Message[]} history - The complete conversation history.
   * @param {(messages: Message[]) => number} _getTokenCount - Unused in this strategy.
   * @returns {Message[]} The truncated message history with the system prompt at the beginning.
   */
  buildContext(
    systemPrompt: string,
    history: Message[],
    _getTokenCount: (messages: Message[]) => number
  ): Message[] {
    return [
      { content: systemPrompt, role: 'system' as const },
      ...history.slice(-this.windowLength),
    ];
  }
}
```
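A quick usage sketch; the API comes straight from the file above, while the prompt and history values are illustrative:

```ts
import { MessageCountContextStrategy } from './MessageCountContextStrategy';

// Keep only the 4 most recent messages, regardless of their token length.
const strategy = new MessageCountContextStrategy(4);

const context = strategy.buildContext(
  'You are a helpful assistant.',
  [
    { role: 'user', content: 'Hi!' },
    { role: 'assistant', content: 'Hello! How can I help?' },
    { role: 'user', content: 'Summarize our chat.' },
  ],
  () => 0 // the token counter is ignored by this strategy
);
// context: the system message followed by (up to) the last 4 history messages
```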
NaiveContextStrategy.ts (new file):

```ts
import { ContextStrategy, Message } from '../../../types/llm';

/**
 * A context strategy that performs no filtering or trimming of the message history.
 *
 * This strategy is ideal when the developer wants to manually manage the conversation
 * context (e.g., using a custom RAG pipeline or external database) and just needs the
 * system prompt prepended to their pre-computed history.
 *
 * @category Utils
 */
export class NaiveContextStrategy implements ContextStrategy {
  /**
   * Builds the context by prepending the system prompt to the entire unfiltered history.
   *
   * @param {string} systemPrompt - The top-level instructions for the model.
   * @param {Message[]} history - The complete conversation history.
   * @param {(messages: Message[]) => number} _getTokenCount - Unused in this strategy.
   * @returns {Message[]} The unedited message history with the system prompt at the beginning.
   */
  buildContext(
    systemPrompt: string,
    history: Message[],
    _getTokenCount: (messages: Message[]) => number
  ): Message[] {
    return [{ content: systemPrompt, role: 'system' as const }, ...history];
  }
}
```

**Collaborator:** IMO something like `NoOpContextStrategy` is better; `Naive` suggests that there is some simple logic underneath.
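For instance, paired with an app-managed retrieval step; the `retrieveRelevantHistory` helper below is hypothetical and not part of this PR:

```ts
import { Message } from '../../../types/llm';
import { NaiveContextStrategy } from './NaiveContextStrategy';

// Hypothetical app-side helper (not part of this PR): the caller trims or
// assembles its own history, e.g. via a RAG pipeline or a database query.
declare function retrieveRelevantHistory(query: string): Message[];

const strategy = new NaiveContextStrategy();
const context = strategy.buildContext(
  'Answer using the retrieved conversation snippets.',
  retrieveRelevantHistory('battery life'),
  () => 0 // unused by this strategy
);
```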
SlidingWindowContextStrategy.ts (new file):

```ts
import { ContextStrategy, Message } from '../../../types/llm';

/**
 * An advanced, token-aware context strategy that dynamically trims the message history
 * so that it fits within the model's physical context limits.
 *
 * This strategy calculates the exact token count of the formatted prompt. If the prompt
 * exceeds the allowed token budget (`maxTokens` - `bufferTokens`), it iteratively
 * removes the oldest messages.
 *
 * @category Utils
 */
export class SlidingWindowContextStrategy implements ContextStrategy {
  /**
   * Initializes the SlidingWindowContextStrategy.
   *
   * @param {number} maxTokens - The absolute maximum number of tokens the model can process (e.g., 4096).
   * @param {number} bufferTokens - The number of tokens to keep free for the model's generated response (e.g., 1000).
   * @param {boolean} allowOrphanedAssistantMessages - Whether to allow orphaned assistant messages when trimming the history. If false, the strategy ensures that an assistant message is not left without its preceding user message.
   */
  constructor(
    private maxTokens: number,
    private bufferTokens: number,
    private allowOrphanedAssistantMessages: boolean = false
  ) {}

  /**
   * Builds the context by iteratively evicting the oldest messages until the total
   * token count is safely within the defined budget.
   *
   * @param {string} systemPrompt - The top-level instructions for the model.
   * @param {Message[]} history - The complete conversation history.
   * @param {(messages: Message[]) => number} getTokenCount - Callback to calculate the exact token count of the rendered template.
   * @returns {Message[]} The trimmed message history; fits the token budget unless a single remaining message already exceeds it.
   */
  buildContext(
    systemPrompt: string,
    history: Message[],
    getTokenCount: (messages: Message[]) => number
  ): Message[] {
    const localHistory = [...history];
    const tokenBudget = this.maxTokens - this.bufferTokens;

    while (localHistory.length > 1) {
      const candidateContext: Message[] = [
        { content: systemPrompt, role: 'system' as const },
        ...localHistory,
      ];

      if (getTokenCount(candidateContext) <= tokenBudget) {
        return candidateContext;
      }

      // Evict the oldest message.
      localHistory.shift();

      if (!this.allowOrphanedAssistantMessages) {
        // Prevent leaving an orphaned "assistant" response at the front.
        if (localHistory.length > 0 && localHistory[0]?.role === 'assistant') {
          localHistory.shift();
        }
      }
    }

    return [
      { content: systemPrompt, role: 'system' as const },
      ...localHistory,
    ];
  }
}
```
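A usage sketch with a stand-in token counter. In the real flow, `getTokenCount` would be backed by the tokenizer (e.g. the native `count_text_tokens` shown above); the characters-per-token estimate and the declared placeholders here are only for illustration:

```ts
import { Message } from '../../../types/llm';
import { SlidingWindowContextStrategy } from './SlidingWindowContextStrategy';

declare const systemPrompt: string; // placeholders for the app's own data
declare const fullHistory: Message[];

// 2048-token model limit, reserving 512 tokens for the model's reply.
const strategy = new SlidingWindowContextStrategy(2048, 512);

// Rough stand-in: ~4 characters per token. The real counter renders the chat
// template and asks the tokenizer for the exact count.
const approxTokenCount = (messages: Message[]) =>
  Math.ceil(messages.reduce((sum, m) => sum + m.content.length, 0) / 4);

const context = strategy.buildContext(systemPrompt, fullHistory, approxTokenCount);
// Oldest messages (user/assistant pairs, by default) are evicted until the
// candidate context fits in the 2048 - 512 = 1536 token budget.
```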
index.ts (new file):

```ts
export { MessageCountContextStrategy } from './MessageCountContextStrategy';
export { SlidingWindowContextStrategy } from './SlidingWindowContextStrategy';
export { NaiveContextStrategy } from './NaiveContextStrategy';
```