Fix double-subtraction of pos_ in TextLLMRunner::generate() (#18727)
meta-codesync[bot] merged 1 commit into main.
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18727
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 3 Pending, 3 Unrelated Failures as of commit 2332004 with merge base 19f7ff2.
NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@kirklandsign has exported this pull request. If you are a Meta employee, you can view the originating Diff in D99742232.
Pull request overview
Fixes token budget resolution in TextLLMRunner::generate() for multi-turn conversations when seq_len is set, aligning behavior with multimodal_runner so occupied KV-cache positions are correctly accounted for.
Changes:
- Stop pre-adjusting `max_context_len` by `pos_`; use the raw metadata value instead.
- Resolve `max_new_tokens` using the full occupied position count (`pos_` after prefill; sketched below), and tighten the max-context prefill guard accordingly.
- Add a regression test covering the multi-turn + `seq_len` case.
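For context, here is a minimal sketch of the corrected budget math, assuming `resolve_max_new_tokens` conceptually computes `min(seq_len, max_context_len)` minus the occupied count. The real implementation in `GenerationConfig` also handles `max_new_tokens` and unset values, so treat this as an approximation rather than the actual code:

```cpp
#include <algorithm>
#include <cstdint>

// Sketch only: approximates the seq_len path of resolve_max_new_tokens.
// max_context_len is the raw metadata value (no longer pre-reduced by pos_);
// occupied is pos_ after prefill, i.e. prior turns plus the new prompt tokens.
int64_t resolve_budget_sketch(int64_t max_context_len, int64_t seq_len, int64_t occupied) {
  int64_t budget = std::min(seq_len, max_context_len) - occupied;
  return std::max<int64_t>(budget, 0);
}
```

Before this change, the second argument was effectively only the new prompt length while `max_context_len` had already been reduced by `pos_`, so the `seq_len` bound was never charged for positions occupied by earlier turns.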
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| extension/llm/runner/text_llm_runner.cpp | Corrects max token resolution by using raw max_context_len and passing occupied positions (pos_) into resolve_max_new_tokens(). |
| extension/llm/runner/test/test_text_llm_runner.cpp | Adds regression coverage to ensure seq_len limits respect prior-turn pos_ occupancy. |
    // Resolve max_new_tokens. pos_ now reflects all occupied positions
    // (including prompt tokens just prefilled).
    int max_new_tokens =
    -   config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);
    +   config.resolve_max_new_tokens(max_context_len, pos_);
GenerationConfig::resolve_max_new_tokens() takes int32_t parameters documented as num_prompt_tokens, but this call passes pos_ (an int64_t occupied-position count). This relies on implicit narrowing conversions and on a broader interpretation of the parameter than the API/docstring (and the pybinding arg name) suggests. Consider updating resolve_max_new_tokens to accept an int64_t occupied token count (or adding a new helper with clearer naming) and adjusting the documentation/bindings to avoid truncation risk and confusion.
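One possible shape for that suggestion is sketched below. The struct, member names, and internals here are hypothetical illustrations of the idea (clamp the wide occupied count before narrowing), not the existing executorch API:

```cpp
#include <algorithm>
#include <cstdint>
#include <limits>

// Hypothetical sketch: give a GenerationConfig-like type an int64_t-friendly
// entry point so callers can pass pos_ without an implicit narrowing conversion.
struct GenerationConfigSketch {
  int32_t seq_len = -1;         // -1 means "not set"
  int32_t max_new_tokens = -1;  // -1 means "not set"

  // Existing-style signature (int32_t), kept as the core budget computation.
  int32_t resolve_max_new_tokens(int32_t max_context_len,
                                 int32_t num_occupied_tokens) const {
    int32_t limit =
        seq_len > 0 ? std::min(seq_len, max_context_len) : max_context_len;
    int32_t budget = std::max(limit - num_occupied_tokens, 0);
    return max_new_tokens > 0 ? std::min(budget, max_new_tokens) : budget;
  }

  // Suggested wider overload: accept an int64_t occupied count (e.g. pos_)
  // and clamp it before narrowing, avoiding silent truncation.
  int32_t resolve_max_new_tokens(int32_t max_context_len,
                                 int64_t num_occupied_positions) const {
    const int64_t clamped = std::clamp<int64_t>(
        num_occupied_positions, 0,
        static_cast<int64_t>(std::numeric_limits<int32_t>::max()));
    return resolve_max_new_tokens(max_context_len,
                                  static_cast<int32_t>(clamped));
  }
};
```

With an overload like this, a caller could pass `pos_` directly and the narrowing happens in one documented place instead of silently at the call site.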
larryliu0820 left a comment: Review automatically exported from Phabricator review in Meta.
Fix double-subtraction of pos_ in TextLLMRunner::generate() (#18727)
Pull Request resolved: pytorch#18727
Summary:
When seq_len is set and pos_ > 0 (multi-turn conversations),
max_context_len was pre-adjusted by subtracting pos_, but
resolve_max_new_tokens then only subtracted num_prompt_tokens
instead of the full occupied position count. This caused
min(seq_len, max_context_len) to use a too-large max_context_len,
producing more tokens than allowed by seq_len.
Fix: use raw metadata value for max_context_len and pass pos_
(which includes prompt tokens after prefill) to
resolve_max_new_tokens, matching multimodal_runner's behavior.
Differential Revision: D99742232
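To make the arithmetic concrete, here is an illustrative example with made-up numbers, assuming `seq_len` caps the total number of occupied positions (prior turns plus the new prompt plus generated tokens):

```cpp
#include <algorithm>
#include <cstdint>

// Illustrative arithmetic only; the values are invented and the exact
// semantics of seq_len are assumed, not taken from the PR.
constexpr int64_t kMaxContextLen = 128;  // raw metadata value
constexpr int64_t kSeqLen = 64;          // user-requested cap
constexpr int64_t kPrevTurnPos = 40;     // pos_ carried over from turn 1
constexpr int64_t kPromptTokens = 10;    // new prompt prefilled in turn 2
constexpr int64_t kOccupied = kPrevTurnPos + kPromptTokens;  // pos_ after prefill

// Old behaviour: max_context_len pre-reduced by pos_, then only the prompt
// length subtracted inside resolve_max_new_tokens.
constexpr int64_t kBuggyBudget =
    std::min(kSeqLen, kMaxContextLen - kPrevTurnPos) - kPromptTokens;  // 54

// Fixed behaviour: raw max_context_len, full occupied count subtracted.
constexpr int64_t kFixedBudget =
    std::min(kSeqLen, kMaxContextLen) - kOccupied;  // 14

static_assert(kBuggyBudget == 54, "old math over-allocates past seq_len");
static_assert(kFixedBudget == 14, "new math charges previously occupied positions");
```

Under these assumed numbers, the old resolution would allow 54 new tokens even though only 14 positions remain under the `seq_len` cap once the earlier turn is counted, which is the overrun the regression test is meant to catch.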