fix/feat!: LLMs context management #819
Conversation
I will update docs after the code gets approved
```cpp
int32_t Runner::count_text_tokens(const std::string &text) const {
  auto encodeResult =
      tokenizer_->encode(text, numOfAddedBoSTokens, numOfAddedEoSTokens);

  if (!encodeResult.ok()) {
    throw rnexecutorch::RnExecutorchError(
        rnexecutorch::RnExecutorchErrorCode::TokenizerError,
        "Encoding failed during token count check.");
  }

  return static_cast<int32_t>(encodeResult.get().size());
}
```
I'm wondering if calling the encoder just to get the size of the encoded text is the most efficient way to obtain it. Can we compute this during the encoding phase instead?
If I understand this correctly, that's not possible for 2 reasons:
- I need to calculate the tokens across the whole messages history + the new message, so I don't know the number of tokens for the new message ahead of time
- I can't reliably reuse the token counts from the runner, because for reasoning models it counts reasoning tokens as well (and later on, these reasoning tokens are not included in the jinja template) - this could lead to discrepancies
Alternatively, maybe we could implement these context strategies in the C++ code, but then we would give no flexibility to the user (?)
Also:
- the context management strategy is configurable, so it can be switched on/off or replaced with another strategy - by default, we do not use the sliding window, but the strategy based on the number of messages (like it was before)
- I think that the encoding phase should not take that long
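To make the "configurable strategy" idea concrete, here is a minimal TypeScript sketch of what a message-count-based strategy could look like. All names (`Message`, `ContextStrategy`, `MessageCountContextStrategy`) and shapes are hypothetical illustrations, not the library's actual interfaces:

```typescript
// Hypothetical shapes -- the real interfaces live in the library code.
interface Message {
  role: "system" | "user" | "assistant";
  content: string;
}

interface ContextStrategy {
  // Returns the subset of messages that should be sent to the runner.
  apply(messages: Message[]): Message[];
}

// Keeps only the most recent `maxMessages` messages, always preserving
// a leading system prompt if one is present.
class MessageCountContextStrategy implements ContextStrategy {
  constructor(private maxMessages: number) {}

  apply(messages: Message[]): Message[] {
    const system = messages[0]?.role === "system" ? [messages[0]] : [];
    const rest = messages.slice(system.length);
    return [...system, ...rest.slice(-this.maxMessages)];
  }
}
```

Because the strategy only sees the plain message array, swapping it for a token-budget or no-op variant requires no changes on the runner side.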
We don't need to bother ourselves with computational complexity here; this is just tokenizer encoding, which is very cheap and only performed once for a given prompt, so it only impacts Time To First Token.
msluszniak
left a comment
Please also resolve conflicts
I changed only docs for
msluszniak
left a comment
In a separate PR I would add a TS code snippet with the llm configuration.
```cpp
void LLM::reset() {
  if (!runner || !runner->is_loaded()) {
    throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded,
                            "Can't interrupt a model that's not loaded");
```

Suggested change:

```diff
-                            "Can't interrupt a model that's not loaded");
+                            "Can't reset a model that's not loaded");
```
```ts
 *
 * @category Utils
 */
export class NaiveContextStrategy implements ContextStrategy {
```
imo sth like `NoOpContextStrategy` is better; `Naive` suggests that there is some simple logic underneath.
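To illustrate the naming point: under a hypothetical `ContextStrategy` interface (names assumed for illustration, not taken from the library), a no-op strategy would simply pass the history through unchanged, which `NoOp` communicates more directly than `Naive`:

```typescript
// Hypothetical interface, sketched for the naming discussion only.
interface Message {
  role: string;
  content: string;
}

interface ContextStrategy {
  apply(messages: Message[]): Message[];
}

// "NoOp" makes it explicit that no trimming happens:
// the full history is forwarded to the runner as-is.
class NoOpContextStrategy implements ContextStrategy {
  apply(messages: Message[]): Message[] {
    return messages;
  }
}
```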
## Description

This PR fixes a few bugs related to the LLMs, caused by mixing two approaches - functional (as we pass the whole messages history each time) and stateful (as we keep `pos_` in the runner, representing at which position the KV cache is), which resulted in 3 bugs:

- broken KV cache for reasoning models - in the runner, we counted tokens generated for the reasoning and included these in the KV cache (`pos_ += num_generated_tokens`), but in the next turns, the jinja template removed these reasoning tokens from the messages history - as a result, the KV cache was incoherent
- duplicated tokens in the KV cache - we were passing the whole messages history to the runner (functional approach), but we were also appending all tokens (prompt and generated) to the KV cache (whose position is represented by `pos_`) - as a result, tokens were "duplicated" in the KV cache and we were running out of available tokens very fast (exceeding `context_window_length`)
- stateful TS functional API - even though our `generate()` method is called functional, it kept internal state in the runner (e.g. `pos_`)

These bugs were fixed by resetting the runner before each generation, which makes it truly functional - old messages are prefilled and the KV cache can still be used during the generation phase.

Additionally, this PR adds `ContextStrategy` to the `ChatConfig` interface, so now it's possible to define (or use one of the already implemented) strategies for managing context (e.g. naive, message count based, sliding window) - this gives us more flexibility and lets the user decide what's best for their use case. From now on, `SlidingWindowContextStrategy` is also configured as the default one.

### Introduces a breaking change?

- [x] Yes
- [ ] No

These changes will not break anything unless the max number of messages is modified (I removed `contextWindowLength` from `ChatConfig` and replaced it with `contextStrategy`).

### Type of change

- [x] Bug fix (change which fixes an issue)
- [x] New feature (change which adds functionality)
- [ ] Documentation update (improves or adds clarity to existing documentation)
- [ ] Other (chores, tests, code style improvements etc.)

### Tested on

- [x] iOS
- [x] Android

### Testing instructions

Run the example llm app, open the executorch logs (`adb logcat | grep -i "executorch"` for example) and check whether the token counts are properly aligned and `pos_` is correct.

To test different context management strategies, change `contextStrategy` in the llm app and modify the model configuration.

### Screenshots

<!-- Add screenshots here, if applicable -->

### Related issues

#776

### Checklist

- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have updated the documentation accordingly
- [x] My changes generate no new warnings

### Additional notes

Position in KV cache, number of prompt tokens, and number of generated tokens for both non-reasoning and reasoning models BEFORE the changes.

LLAMA 3.2 1B SPINQUANT (without reasoning)

| pos_ | Prompt tokens | Generated tokens |
|------------------|---------------|------------------|
| 0 | 335 | 269 |
| 604=269+335 | 872 | 372 |
| 1848=604+872+372 | 1513 | CRASH |

QWEN 3.0 0.6B QUANTIZED (with reasoning)

| pos_ | Prompt tokens | Generated tokens |
|------------------|---------------|------------------|
| 0 | 309 | 457 |
| 766=309+457 | 617 (<766!) | 192 |
| 1575=766+617+192 | 925 (<1575!) | CRASH |
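For readers skimming the PR, a minimal sketch of what a token-budget sliding window could look like. This is an illustration under assumed names (`Message`, `SlidingWindowContextStrategy`, a `countTokens` callback such as one backed by the runner's `count_text_tokens`), not the library's actual implementation:

```typescript
// Hypothetical message shape for this sketch.
interface Message {
  role: string;
  content: string;
}

// Keeps as many of the newest messages as fit into a token budget,
// walking the history from the newest message backwards.
class SlidingWindowContextStrategy {
  constructor(
    private maxTokens: number,
    private countTokens: (text: string) => number,
  ) {}

  apply(messages: Message[]): Message[] {
    const kept: Message[] = [];
    let used = 0;
    for (let i = messages.length - 1; i >= 0; i--) {
      const cost = this.countTokens(messages[i].content);
      if (used + cost > this.maxTokens) break; // budget exhausted
      used += cost;
      kept.unshift(messages[i]); // preserve chronological order
    }
    return kept;
  }
}
```

Counting tokens per message (rather than trusting the runner's `pos_`) sidesteps the reasoning-token discrepancy described above, since only the text that actually survives the jinja template is measured.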