Commit 8b7ba1d

rework context strategies maxContextLength
1 parent 14a6b9c commit 8b7ba1d

11 files changed

Lines changed: 53 additions & 22 deletions

apps/llm/app/llm/index.tsx

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ function LLMScreen() {
   useEffect(() => {
     llm.configure({
       chatConfig: {
-        contextStrategy: new SlidingWindowContextStrategy(2048, 512),
+        contextStrategy: new SlidingWindowContextStrategy(512),
       },
     });
     // eslint-disable-next-line react-hooks/exhaustive-deps
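
Taken together with the `ContextStrategy` changes below, the example app no longer passes a fixed window size to `SlidingWindowContextStrategy`: the strategy now learns the model's real context length through `buildContext()`, so only one argument remains. A minimal sketch of the updated configuration, assuming the package-root import and that the remaining argument is the token head-room reserved for the model's reply (neither is confirmed by this diff):

```ts
// Sketch only: the argument's meaning is inferred; the window size itself
// now comes from the model metadata via getMaxContextLength().
import { SlidingWindowContextStrategy } from 'react-native-executorch';

const strategy = new SlidingWindowContextStrategy(512);
```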

packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h

Lines changed: 5 additions & 0 deletions
@@ -133,6 +133,11 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
         synchronousHostFunction<&Model::setTopp>,
         "setTopp"));

+    addFunctions(JSI_EXPORT_FUNCTION(
+        ModelHostObject<Model>,
+        synchronousHostFunction<&Model::getMaxContextLength>,
+        "getMaxContextLength"));
+
     addFunctions(
         JSI_EXPORT_FUNCTION(ModelHostObject<Model>, unload, "unload"));
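
The export above registers `getMaxContextLength` as a synchronous host function, so JS can read the value inline, with no promise round-trip. A rough sketch of the call from the JS side (the `nativeModule` handle is assumed; the controller diff below shows it is the object that already carries `countTextTokens`):

```ts
// Synchronous JSI call; no await needed.
declare const nativeModule: { getMaxContextLength(): number };

const maxContextLength: number = nativeModule.getMaxContextLength();
```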

packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp

Lines changed: 10 additions & 0 deletions
@@ -146,6 +146,16 @@ void LLM::setTopp(float topp) {
   }
   runner->set_topp(topp);
 }
+
+size_t LLM::getMaxContextLength() const {
+  if (!runner || !runner->is_loaded()) {
+    throw RnExecutorchError(
+        RnExecutorchErrorCode::ModuleNotLoaded,
+        "Can't get context length from a model that's not loaded");
+  }
+  return runner->get_max_context_length();
+}
+
 void LLM::unload() noexcept { runner.reset(nullptr); }

 } // namespace rnexecutorch::models::llm
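
The guard means the value is only readable once a runner exists and has loaded; before that, a `ModuleNotLoaded` error propagates to JS. A hedged sketch of that failure mode (the exact JS error shape is an assumption; the diff only shows the native throw):

```ts
declare const nativeModule: { getMaxContextLength(): number };

try {
  console.log(`context window: ${nativeModule.getMaxContextLength()} tokens`);
} catch (e) {
  // Reached when the call happens before the model finishes loading.
  console.warn('model not loaded yet', e);
}
```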

packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ class LLM : public BaseModel {
   void setTemperature(float temperature);
   void setTopp(float topp);
   void setTimeInterval(size_t timeInterval);
+  size_t getMaxContextLength() const;

 private:
   std::unique_ptr<example::Runner> runner;

packages/react-native-executorch/common/runner/runner.cpp

Lines changed: 20 additions & 13 deletions
@@ -342,6 +342,26 @@ void Runner::set_topp(float topp) noexcept {
   }
 }

+int32_t Runner::get_max_context_length() const {
+  if (!is_loaded()) {
+    return static_cast<int32_t>(metadata_.at(kMaxContextLen));
+  }
+  return config_.max_context_length;
+}
+
+int32_t Runner::count_text_tokens(const std::string &text) const {
+  auto encodeResult =
+      tokenizer_->encode(text, numOfAddedBoSTokens, numOfAddedEoSTokens);
+
+  if (!encodeResult.ok()) {
+    throw rnexecutorch::RnExecutorchError(
+        rnexecutorch::RnExecutorchErrorCode::TokenizerError,
+        "Encoding failed during token count check.");
+  }
+
+  return static_cast<int32_t>(encodeResult.get().size());
+}
+
 int32_t Runner::resolve_max_new_tokens(int32_t num_prompt_tokens,
                                        int32_t max_seq_len,
                                        int32_t max_context_len,
@@ -368,17 +388,4 @@ int32_t Runner::resolve_max_new_tokens(int32_t num_prompt_tokens,
   return std::max(0, result);
 }

-int32_t Runner::count_text_tokens(const std::string &text) const {
-  auto encodeResult =
-      tokenizer_->encode(text, numOfAddedBoSTokens, numOfAddedEoSTokens);
-
-  if (!encodeResult.ok()) {
-    throw rnexecutorch::RnExecutorchError(
-        rnexecutorch::RnExecutorchErrorCode::TokenizerError,
-        "Encoding failed during token count check.");
-  }
-
-  return static_cast<int32_t>(encodeResult.get().size());
-}
-
 } // namespace example
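
Note the split: before the model loads, `get_max_context_length` falls back to the model metadata (`kMaxContextLen`); once loaded, it returns the resolved `config_.max_context_length`. `count_text_tokens` itself is unchanged, only moved next to it. Together the two calls give a strategy everything needed for a token budget; a sketch under assumed names, with the 512-token output reserve purely illustrative:

```ts
// Hypothetical wrapper over the two synchronous native calls this commit
// surfaces; the module shape is taken from the controller diff below.
interface LLMNativeModule {
  getMaxContextLength(): number;
  countTextTokens(rendered: string): number;
}

// True when a rendered prompt fits the window, leaving room for output.
function fitsContext(native: LLMNativeModule, renderedPrompt: string): boolean {
  const reserveForOutput = 512; // illustrative, not part of the API
  const budget = native.getMaxContextLength() - reserveForOutput;
  return native.countTextTokens(renderedPrompt) <= budget;
}
```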

packages/react-native-executorch/common/runner/runner.h

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ class Runner : public llm::IRunner {
   void set_temperature(float temperature) noexcept;
   void set_topp(float topp) noexcept;
   int32_t count_text_tokens(const std::string &text) const;
+  int32_t get_max_context_length() const;

   void stop() override;
   void reset() override;

packages/react-native-executorch/src/controllers/LLMController.ts

Lines changed: 2 additions & 1 deletion
@@ -321,11 +321,12 @@ export class LLMController {
       );
       return this.nativeModule.countTextTokens(rendered);
     };
-
+    const maxContextLength = this.nativeModule.getMaxContextLength();
     const messageHistoryWithPrompt =
       this.chatConfig.contextStrategy.buildContext(
        this.chatConfig.systemPrompt,
        updatedHistory,
+       maxContextLength,
        countTokensCallback
      );

packages/react-native-executorch/src/types/llm.ts

Lines changed: 2 additions & 0 deletions
@@ -261,12 +261,14 @@ export interface ContextStrategy {
   * Constructs the final array of messages to be sent to the model for the current inference step.
   * * @param systemPrompt - The top-level instructions or persona assigned to the model.
   * @param history - The complete conversation history up to the current point.
+  * @param maxContextLength - The maximum number of tokens that the model can keep in the context.
   * @param getTokenCount - A callback function provided by the LLM controller that calculates the exact number of tokens a specific array of messages will consume once formatted.
   * @returns The optimized array of messages, ready to be processed by the model.
   */
  buildContext(
    systemPrompt: string,
    history: Message[],
+   maxContextLength: number,
    getTokenCount: (messages: Message[]) => number
  ): Message[];
}
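
With `maxContextLength` now part of the `ContextStrategy` contract, a custom strategy can budget tokens against the model's actual window instead of a hard-coded one. A minimal sketch; the package-root type import and the 512-token default reserve are assumptions, not shown in this diff:

```ts
import type { ContextStrategy, Message } from 'react-native-executorch';

// Hypothetical strategy: drop the oldest non-system messages until the
// formatted context fits within maxContextLength minus an output reserve.
export class TokenBudgetContextStrategy implements ContextStrategy {
  constructor(private readonly reservedOutputTokens = 512) {}

  buildContext(
    systemPrompt: string,
    history: Message[],
    maxContextLength: number,
    getTokenCount: (messages: Message[]) => number
  ): Message[] {
    const budget = maxContextLength - this.reservedOutputTokens;
    const messages: Message[] = [
      { content: systemPrompt, role: 'system' as const },
      ...history,
    ];
    // splice(1, 1) trims the oldest turn while keeping the system prompt.
    while (messages.length > 1 && getTokenCount(messages) > budget) {
      messages.splice(1, 1);
    }
    return messages;
  }
}
```

A strategy like this drops into the example app unchanged, via `chatConfig: { contextStrategy: new TokenBudgetContextStrategy() }`.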

packages/react-native-executorch/src/utils/llms/context_strategy/MessageCountContextStrategy.ts

Lines changed: 2 additions & 0 deletions
@@ -22,11 +22,13 @@ export class MessageCountContextStrategy implements ContextStrategy {
   * @param {string} systemPrompt - The top-level instructions for the model.
   * @param {Message[]} history - The complete conversation history.
   * @param {(messages: Message[]) => number} _getTokenCount - Unused in this strategy.
+  * @param {number} _maxContextLength - Unused in this strategy.
   * @returns {Message[]} The truncated message history with the system prompt at the beginning.
   */
  buildContext(
    systemPrompt: string,
    history: Message[],
+   _maxContextLength: number,
    _getTokenCount: (messages: Message[]) => number
  ): Message[] {
    return [
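
By contrast, the count-based strategy only threads the new parameter through without reading it. Illustrative usage, mirroring the example app's `configure` call; the constructor argument (keep the last N messages) is inferred from the class name, not from this diff:

```ts
// `llm` comes from useLLM() as in the example app above.
llm.configure({
  chatConfig: {
    contextStrategy: new MessageCountContextStrategy(10),
  },
});
```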

packages/react-native-executorch/src/utils/llms/context_strategy/NaiveContextStrategy.ts

Lines changed: 3 additions & 2 deletions
@@ -3,8 +3,7 @@ import { ContextStrategy, Message } from '../../../types/llm';
 /**
  * A context strategy that performs no filtering or trimming of the message history.
  * * This strategy is ideal when the developer wants to manually manage the conversation
- * context (e.g., using a custom RAG pipeline or external database) and just needs the
- * system prompt prepended to their pre-computed history.
+ * context.
  *
  * @category Utils
  */
@@ -14,12 +13,14 @@ export class NaiveContextStrategy implements ContextStrategy {
   *
   * @param {string} systemPrompt - The top-level instructions for the model.
   * @param {Message[]} history - The complete conversation history.
+  * @param {number} _maxContextLength - Unused in this strategy.
   * @param {(messages: Message[]) => number} _getTokenCount - Unused in this strategy.
   * @returns {Message[]} The unedited message history with the system prompt at the beginning.
   */
  buildContext(
    systemPrompt: string,
    history: Message[],
+   _maxContextLength: number,
    _getTokenCount: (messages: Message[]) => number
  ): Message[] {
    return [{ content: systemPrompt, role: 'system' as const }, ...history];
