Merged
Changes from 7 commits
1 change: 1 addition & 0 deletions .cspell-wordlist.txt
@@ -21,6 +21,7 @@ microcontrollers
notimestamps
seqs
smollm
llms
qwen
XNNPACK
EFFICIENTNET
15 changes: 14 additions & 1 deletion apps/llm/app/llm/index.tsx
@@ -11,7 +11,11 @@ import {
View,
} from 'react-native';
import SendIcon from '../../assets/icons/send_icon.svg';
import { useLLM, LLAMA3_2_1B_SPINQUANT } from 'react-native-executorch';
import {
useLLM,
LLAMA3_2_1B_SPINQUANT,
SlidingWindowContextStrategy,
} from 'react-native-executorch';
import PauseIcon from '../../assets/icons/pause_icon.svg';
import ColorPalette from '../../colors';
import Messages from '../../components/Messages';
@@ -32,6 +36,15 @@ function LLMScreen() {

const llm = useLLM({ model: LLAMA3_2_1B_SPINQUANT });

useEffect(() => {
llm.configure({
chatConfig: {
contextStrategy: new SlidingWindowContextStrategy(2048, 512),
},
});
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);

useEffect(() => {
if (llm.error) {
console.log('LLM error:', llm.error);
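For comparison, a minimal sketch (not part of this diff) of swapping in the message-count strategy instead, assuming configure() accepts a partial chatConfig as in the snippet above; MessageCountContextStrategy mirrors the previous contextWindowLength behaviour:

import {
  useLLM,
  LLAMA3_2_1B_SPINQUANT,
  MessageCountContextStrategy,
} from 'react-native-executorch';
import { useEffect } from 'react';

function MessageCountExample() {
  // Hypothetical component, for illustration only.
  const llm = useLLM({ model: LLAMA3_2_1B_SPINQUANT });

  useEffect(() => {
    llm.configure({
      chatConfig: {
        // Keep only the 5 most recent messages, matching the old
        // DEFAULT_CONTEXT_WINDOW_LENGTH behaviour.
        contextStrategy: new MessageCountContextStrategy(5),
      },
    });
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, []);

  return null;
}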
@@ -112,6 +112,10 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
synchronousHostFunction<&Model::getPromptTokenCount>,
"getPromptTokenCount"));

addFunctions(JSI_EXPORT_FUNCTION(
ModelHostObject<Model>,
synchronousHostFunction<&Model::countTextTokens>, "countTextTokens"));

addFunctions(
JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
synchronousHostFunction<&Model::setCountInterval>,
Expand All @@ -131,6 +135,10 @@ template <typename Model> class ModelHostObject : public JsiHostObject {

addFunctions(
JSI_EXPORT_FUNCTION(ModelHostObject<Model>, unload, "unload"));

addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
synchronousHostFunction<&Model::reset>,
"reset"));
}

if constexpr (meta::SameAs<Model, models::text_to_image::TextToImage>) {
@@ -64,6 +64,14 @@ void LLM::interrupt() {
runner->stop();
}

void LLM::reset() {
if (!runner || !runner->is_loaded()) {
throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded,
"Can't interrupt a model that's not loaded");
Collaborator:

Suggested change:
        "Can't interrupt a model that's not loaded");
        "Can't reset a model that's not loaded");

}
runner->reset();
}

size_t LLM::getGeneratedTokenCount() const noexcept {
if (!runner || !runner->is_loaded()) {
return 0;
@@ -78,6 +86,15 @@ size_t LLM::getPromptTokenCount() const noexcept {
return runner->stats_.num_prompt_tokens;
}

size_t LLM::countTextTokens(std::string text) const {
if (!runner || !runner->is_loaded()) {
throw RnExecutorchError(
RnExecutorchErrorCode::ModuleNotLoaded,
"Can't count tokens from a model that's not loaded");
}
return runner->count_text_tokens(text);
}

size_t LLM::getMemoryLowerBound() const noexcept {
return memorySizeLowerBound;
}
@@ -116,7 +133,7 @@ void LLM::setTemperature(float temperature) {
"Temperature must be non-negative");
}
runner->set_temperature(temperature);
};
}

void LLM::setTopp(float topp) {
if (!runner || !runner->is_loaded()) {
@@ -21,9 +21,11 @@ class LLM : public BaseModel {
std::string generate(std::string input,
std::shared_ptr<jsi::Function> callback);
void interrupt();
void reset();
void unload() noexcept;
size_t getGeneratedTokenCount() const noexcept;
size_t getPromptTokenCount() const noexcept;
size_t countTextTokens(std::string text) const;
size_t getMemoryLowerBound() const noexcept;
void setCountInterval(size_t countInterval);
void setTemperature(float temperature);
13 changes: 13 additions & 0 deletions packages/react-native-executorch/common/runner/runner.cpp
@@ -368,4 +368,17 @@ int32_t Runner::resolve_max_new_tokens(int32_t num_prompt_tokens,
return std::max(0, result);
}

int32_t Runner::count_text_tokens(const std::string &text) const {
auto encodeResult =
tokenizer_->encode(text, numOfAddedBoSTokens, numOfAddedEoSTokens);

if (!encodeResult.ok()) {
throw rnexecutorch::RnExecutorchError(
rnexecutorch::RnExecutorchErrorCode::TokenizerError,
"Encoding failed during token count check.");
}

return static_cast<int32_t>(encodeResult.get().size());
}

Member:

I'm wondering if calling the encoder specifically to get only the size of the encoded text is the most efficient way to get the size. Can we compute this during the encoding phase?

Contributor Author:

If I understand this correctly, that's not possible for two reasons:

  • I need to count the tokens of the whole message history plus the new message, so I don't know the token count for the new message in isolation.
  • I can't reliably use the token counts from the runner, because for reasoning models it also counts reasoning tokens (which are later excluded from the Jinja template); that could lead to discrepancies.

Contributor Author (@mateuszlampert, Feb 19, 2026):

Alternatively, maybe we could implement these context strategies in the C++ code, but then we would give the user no flexibility (?)

Also:

  • The context management strategy is configurable, so it can be switched on/off or replaced with another strategy. By default we do not use the sliding window, but the strategy based on the number of messages (as it was before).
  • I think the encoding phase should not take that long.

Collaborator:

We don't need to bother ourselves with computational complexity here; this is just tokenizer encoding, which is very cheap and only performed once for a given prompt, so it only impacts Time To First Token.
} // namespace example
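To make the flexibility discussed in the thread above concrete, here is a minimal sketch (not part of this PR) of a user-defined strategy written against the ContextStrategy interface added in src/types/llm.ts. The class name, the pinning behaviour, and the assumption that ContextStrategy and Message are importable from the package root are illustrative only:

import type { ContextStrategy, Message } from 'react-native-executorch';

// Illustrative only: keeps the first user/assistant exchange pinned and fills
// the remaining token budget with the most recent messages, using the
// getTokenCount callback supplied by LLMController (which renders the chat
// template and calls the native countTextTokens).
class PinnedHeadContextStrategy implements ContextStrategy {
  constructor(
    private maxTokens: number,
    private bufferTokens: number
  ) {}

  buildContext(
    systemPrompt: string,
    history: Message[],
    getTokenCount: (messages: Message[]) => number
  ): Message[] {
    const budget = this.maxTokens - this.bufferTokens;
    const head = history.slice(0, 2); // pinned first exchange
    let tail = history.slice(2);

    while (tail.length > 0) {
      const candidate: Message[] = [
        { content: systemPrompt, role: 'system' },
        ...head,
        ...tail,
      ];
      if (getTokenCount(candidate) <= budget) {
        return candidate;
      }
      tail = tail.slice(1); // evict the oldest non-pinned message
    }

    return [{ content: systemPrompt, role: 'system' }, ...head];
  }
}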
1 change: 1 addition & 0 deletions packages/react-native-executorch/common/runner/runner.h
@@ -50,6 +50,7 @@ class Runner : public llm::IRunner {
void set_time_interval(size_t time_interval);
void set_temperature(float temperature) noexcept;
void set_topp(float topp) noexcept;
int32_t count_text_tokens(const std::string &text) const;

void stop() override;
void reset() override;
@@ -1,4 +1,5 @@
import { ChatConfig, Message } from '../types/llm';
import { MessageCountContextStrategy } from '../utils/llms/context_strategy/MessageCountContextStrategy';

/**
* Default system prompt used to guide the behavior of Large Language Models (LLMs).
@@ -48,5 +49,7 @@ export const DEFAULT_CONTEXT_WINDOW_LENGTH = 5;
export const DEFAULT_CHAT_CONFIG: ChatConfig = {
systemPrompt: DEFAULT_SYSTEM_PROMPT,
initialMessageHistory: DEFAULT_MESSAGE_HISTORY,
contextWindowLength: DEFAULT_CONTEXT_WINDOW_LENGTH,
contextStrategy: new MessageCountContextStrategy(
DEFAULT_CONTEXT_WINDOW_LENGTH
),
};
29 changes: 22 additions & 7 deletions packages/react-native-executorch/src/controllers/LLMController.ts
@@ -227,6 +227,7 @@ export class LLMController {
}
try {
this.isGeneratingCallback(true);
this.nativeModule.reset();
const response = await this.nativeModule.generate(input, this.onToken);
return this.filterSpecialTokens(response);
} catch (e) {
@@ -304,15 +305,29 @@
}

public async sendMessage(message: string): Promise<string> {
this.messageHistoryCallback([
const updatedHistory = [
...this._messageHistory,
{ content: message, role: 'user' },
]);

const messageHistoryWithPrompt: Message[] = [
{ content: this.chatConfig.systemPrompt, role: 'system' },
...this._messageHistory.slice(-this.chatConfig.contextWindowLength),
{ content: message, role: 'user' as const },
];
this.messageHistoryCallback(updatedHistory);

const countTokensCallback = (messages: Message[]) => {
const rendered = this.applyChatTemplate(
messages,
this.tokenizerConfig,
this.toolsConfig?.tools,
// eslint-disable-next-line camelcase
{ tools_in_user_message: false, add_generation_prompt: true }
);
return this.nativeModule.countTextTokens(rendered);
};

const messageHistoryWithPrompt =
this.chatConfig.contextStrategy.buildContext(
this.chatConfig.systemPrompt,
updatedHistory,
countTokensCallback
);

const response = await this.generate(
messageHistoryWithPrompt,
1 change: 1 addition & 0 deletions packages/react-native-executorch/src/index.ts
@@ -114,6 +114,7 @@ export * from './modules/general/ExecutorchModule';
// utils
export * from './utils/ResourceFetcher';
export * from './utils/llm';
export * from './utils/llms/context_strategy';

// types
export * from './types/objectDetection';
24 changes: 22 additions & 2 deletions packages/react-native-executorch/src/types/llm.ts
@@ -212,13 +212,13 @@ export type LLMTool = Object;
*
* @category Types
* @property {Message[]} initialMessageHistory - An array of `Message` objects that represent the conversation history. This can be used to provide initial context to the model.
* @property {number} contextWindowLength - The number of messages from the current conversation that the model will use to generate a response. The higher the number, the more context the model will have. Keep in mind that using larger context windows will result in longer inference time and higher memory usage.
* @property {string} systemPrompt - Often used to tell the model what is its purpose, for example - "Be a helpful translator".
* @property {ContextStrategy} contextStrategy - Defines a strategy for managing the conversation context window and message history.
*/
export interface ChatConfig {
initialMessageHistory: Message[];
contextWindowLength: number;
systemPrompt: string;
contextStrategy: ContextStrategy;
}

/**
@@ -251,6 +251,26 @@ export interface GenerationConfig {
batchTimeInterval?: number;
}

/**
* Defines a strategy for managing the conversation context window and message history.
*
* @category Types
*/
export interface ContextStrategy {
/**
* Constructs the final array of messages to be sent to the model for the current inference step.
*
* @param systemPrompt - The top-level instructions or persona assigned to the model.
* @param history - The complete conversation history up to the current point.
* @param getTokenCount - A callback function provided by the LLM controller that calculates the exact number of tokens a specific array of messages will consume once formatted.
* @returns The optimized array of messages, ready to be processed by the model.
*/
buildContext(
systemPrompt: string,
history: Message[],
getTokenCount: (messages: Message[]) => number
): Message[];
}

/**
* Special tokens used in Large Language Models (LLMs).
*
@@ -0,0 +1,37 @@
import { DEFAULT_CONTEXT_WINDOW_LENGTH } from '../../../constants/llmDefaults';
import { ContextStrategy, Message } from '../../../types/llm';

/**
* A simple context strategy that retains a fixed number of the most recent messages.
* This strategy trims the conversation history based purely on the message count.
*
* @category Utils
*/
export class MessageCountContextStrategy implements ContextStrategy {
/**
* Initializes the MessageCountContextStrategy.
*
* @param {number} windowLength - The maximum number of recent messages to retain in the context. Defaults to {@link DEFAULT_CONTEXT_WINDOW_LENGTH}.
*/
constructor(
private readonly windowLength: number = DEFAULT_CONTEXT_WINDOW_LENGTH
) {}

/**
* Builds the context by slicing the history to retain only the most recent `windowLength` messages.
*
* @param {string} systemPrompt - The top-level instructions for the model.
* @param {Message[]} history - The complete conversation history.
* @param {(messages: Message[]) => number} _getTokenCount - Unused in this strategy.
* @returns {Message[]} The truncated message history with the system prompt at the beginning.
*/
buildContext(
systemPrompt: string,
history: Message[],
_getTokenCount: (messages: Message[]) => number
): Message[] {
return [
{ content: systemPrompt, role: 'system' as const },
...history.slice(-this.windowLength),
];
}
}
@@ -0,0 +1,27 @@
import { ContextStrategy, Message } from '../../../types/llm';

/**
* A context strategy that performs no filtering or trimming of the message history.
*
* This strategy is ideal when the developer wants to manually manage the conversation
* context (e.g., using a custom RAG pipeline or external database) and just needs the
* system prompt prepended to their pre-computed history.
*
* @category Utils
*/
export class NaiveContextStrategy implements ContextStrategy {
Collaborator:

IMO something like NoOpContextStrategy is a better name; "Naive" suggests that there is some simple logic underneath.

/**
* Builds the context by prepending the system prompt to the entire unfiltered history.
*
* @param {string} systemPrompt - The top-level instructions for the model.
* @param {Message[]} history - The complete conversation history.
* @param {(messages: Message[]) => number} _getTokenCount - Unused in this strategy.
* @returns {Message[]} The unedited message history with the system prompt at the beginning.
*/
buildContext(
systemPrompt: string,
history: Message[],
_getTokenCount: (messages: Message[]) => number
): Message[] {
return [{ content: systemPrompt, role: 'system' as const }, ...history];
}
}
@@ -0,0 +1,67 @@
import { ContextStrategy, Message } from '../../../types/llm';

/**
* An advanced, token-aware context strategy that dynamically trims the message history
* to ensure it fits within the model's physical context limits.
*
* This strategy calculates the exact token count of the formatted prompt. If the prompt
* exceeds the allowed token budget (`maxTokens` - `bufferTokens`), it iteratively
* removes the oldest messages.
*
* @category Utils
*/
export class SlidingWindowContextStrategy implements ContextStrategy {
/**
* Initializes the SlidingWindowContextStrategy.
* @param {number} maxTokens - The absolute maximum number of tokens the model can process (e.g., 4096).
* @param {number} bufferTokens - The number of tokens to keep free for the model's generated response (e.g., 1000).
* @param {boolean} allowOrphanedAssistantMessages - Whether to allow orphaned assistant messages when trimming the history. If false, the strategy will ensure that an assistant message is not left without its preceding user message.
*/
constructor(
private maxTokens: number,
private bufferTokens: number,
private allowOrphanedAssistantMessages: boolean = false
) {}

/**
* Builds the context by iteratively evicting the oldest messages until the total
* token count is safely within the defined budget.
*
* @param {string} systemPrompt - The top-level instructions for the model.
* @param {Message[]} history - The complete conversation history.
* @param {(messages: Message[]) => number} getTokenCount - Callback to calculate the exact token count of the rendered template.
* @returns {Message[]} The optimized message history guaranteed to fit the token budget.
*/
buildContext(
systemPrompt: string,
history: Message[],
getTokenCount: (messages: Message[]) => number
): Message[] {
let localHistory = [...history];
const tokenBudget = this.maxTokens - this.bufferTokens;

while (localHistory.length > 1) {
const candidateContext: Message[] = [
{ content: systemPrompt, role: 'system' as const },
...localHistory,
];

if (getTokenCount(candidateContext) <= tokenBudget) {
return candidateContext;
}

localHistory.shift();

if (!this.allowOrphanedAssistantMessages) {
// Prevent leaving an orphaned "assistant" response
if (localHistory.length > 0 && localHistory[0]?.role === 'assistant') {
localHistory.shift();
}
}
}

return [
{ content: systemPrompt, role: 'system' as const },
...localHistory,
];
}
}
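A rough usage sketch (not from the PR) of how the token budget plays out, using a stub token counter that charges 50 tokens per message; in the real flow the counts come from the native countTextTokens call on the rendered chat template:

const strategy = new SlidingWindowContextStrategy(2048, 512); // budget = 2048 - 512 = 1536

// 40 alternating user/assistant messages.
const history: Message[] = Array.from({ length: 40 }, (_, i) => ({
  role: i % 2 === 0 ? ('user' as const) : ('assistant' as const),
  content: `message ${i}`,
}));

// Stub: pretend every message (including the system prompt) costs 50 tokens.
const stubTokenCount = (messages: Message[]) => messages.length * 50;

const context = strategy.buildContext('Be concise.', history, stubTokenCount);
// 41 * 50 = 2050 tokens exceeds the 1536 budget, so the oldest messages are
// evicted pairwise (avoiding an orphaned assistant reply) until 29 * 50 = 1450
// fits: the result is the system prompt plus the 28 most recent messages,
// starting on a user message.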
@@ -0,0 +1,3 @@
export { MessageCountContextStrategy } from './MessageCountContextStrategy';
export { SlidingWindowContextStrategy } from './SlidingWindowContextStrategy';
export { NaiveContextStrategy } from './NaiveContextStrategy';