From 0bd2f5216b834a623a2e2b9209f301c7be78d48c Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 23 Apr 2026 18:41:29 -0700 Subject: [PATCH] Widen resolve_max_new_tokens parameters to int64_t and rename for clarity (#18917) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: The second parameter was named `num_prompt_tokens` (int32_t) but all callers (TextLLMRunner, MultimodalRunner) actually pass `pos_` (int64_t), which represents the total number of occupied positions in the context window — not just the current prompt's tokens. - Rename `num_prompt_tokens` → `num_tokens_occupied` to match actual semantics - Widen both parameters from int32_t to int64_t to eliminate implicit narrowing conversions from int64_t callers - Use int64_t internally to avoid truncation during intermediate arithmetic - Update pybinding arg name, .pyi type stub, tests, and docs Reviewed By: larryliu0820 Differential Revision: D99769848 --- docs/source/llm/run-with-c-plus-plus.md | 4 +- extension/llm/runner/_llm_runner.pyi | 5 ++- extension/llm/runner/irunner.h | 32 +++++++++------ extension/llm/runner/pybindings.cpp | 2 +- .../runner/test/test_generation_config.cpp | 40 +++++++++---------- .../llm/runner/test/test_runner_pybindings.py | 8 ++++ 6 files changed, 53 insertions(+), 38 deletions(-) diff --git a/docs/source/llm/run-with-c-plus-plus.md b/docs/source/llm/run-with-c-plus-plus.md index 217afad847b..b6c6082c3a6 100644 --- a/docs/source/llm/run-with-c-plus-plus.md +++ b/docs/source/llm/run-with-c-plus-plus.md @@ -183,13 +183,13 @@ struct GenerationConfig { int32_t num_eos = 0; // Number of EOS tokens to add // Helper method to resolve the actual max_new_tokens based on constraints - int32_t resolve_max_new_tokens(int32_t max_context_len, int32_t num_prompt_tokens) const; + int32_t resolve_max_new_tokens(int64_t max_context_len, int64_t num_tokens_occupied) const; }; ``` The `resolve_max_new_tokens` method handles the logic of determining how many tokens can be generated based on: - The model's maximum context length -- The number of tokens in the prompt +- The number of token positions already occupied in the context window - The user-specified maximum sequence length and maximum new tokens ### Implementation Components diff --git a/extension/llm/runner/_llm_runner.pyi b/extension/llm/runner/_llm_runner.pyi index 20333578763..271cf1e1540 100644 --- a/extension/llm/runner/_llm_runner.pyi +++ b/extension/llm/runner/_llm_runner.pyi @@ -47,14 +47,15 @@ class GenerationConfig: ... def resolve_max_new_tokens( - self, max_context_len: int, num_prompt_tokens: int + self, max_context_len: int, num_tokens_occupied: int ) -> int: """ Resolve the maximum number of new tokens to generate based on constraints. Args: max_context_len: The maximum context length supported by the model - num_prompt_tokens: The number of tokens in the input prompt + num_tokens_occupied: The number of token positions already occupied + in the context window (e.g. pos after prefill) Returns: The resolved maximum number of new tokens to generate diff --git a/extension/llm/runner/irunner.h b/extension/llm/runner/irunner.h index 0fcce1f37e4..bb7dd767fea 100644 --- a/extension/llm/runner/irunner.h +++ b/extension/llm/runner/irunner.h @@ -10,6 +10,7 @@ #pragma once +#include #include #include #include @@ -65,36 +66,41 @@ struct GenerationConfig { * * This method calculates the maximum number of new tokens that can be * generated considering both seq_len and max_new_tokens constraints, as well - * as the model's maximum context length and the number of tokens in the - * prompt. + * as the model's maximum context length and how many token positions are + * already occupied (e.g. by prior turns and the current prompt). * * @param max_context_len The maximum context length supported by the model - * @param num_prompt_tokens The number of tokens in the input prompt + * @param num_tokens_occupied The number of token positions already occupied + * in the context window (e.g. pos_ after prefill) * @return The resolved maximum number of new tokens to generate */ int32_t resolve_max_new_tokens( - int32_t max_context_len, - int32_t num_prompt_tokens) const { - int32_t result; + int64_t max_context_len, + int64_t num_tokens_occupied) const { + int64_t result; if (seq_len == -1 && max_new_tokens == -1) { - // Both are -1, use max context len minus prompt tokens - result = max_context_len - num_prompt_tokens; + // Both are -1, use max context len minus occupied tokens + result = max_context_len - num_tokens_occupied; } else if (seq_len == -1 && max_new_tokens != -1) { // Only max_new_tokens is specified - result = std::min(max_new_tokens, max_context_len - num_prompt_tokens); + result = std::min( + static_cast(max_new_tokens), + max_context_len - num_tokens_occupied); } else if (seq_len != -1 && max_new_tokens == -1) { // Only seq_len is specified - result = std::min(seq_len, max_context_len) - num_prompt_tokens; + result = std::min(static_cast(seq_len), max_context_len) - + num_tokens_occupied; } else { // Both are specified result = std::min( - std::min(seq_len, max_context_len) - num_prompt_tokens, - max_new_tokens); + std::min(static_cast(seq_len), max_context_len) - + num_tokens_occupied, + static_cast(max_new_tokens)); } // Ensure result is not negative - return std::max(0, result); + return static_cast(std::max(static_cast(0), result)); } }; diff --git a/extension/llm/runner/pybindings.cpp b/extension/llm/runner/pybindings.cpp index ecd49e6341a..3188b5390c4 100644 --- a/extension/llm/runner/pybindings.cpp +++ b/extension/llm/runner/pybindings.cpp @@ -297,7 +297,7 @@ PYBIND11_MODULE(_llm_runner, m) { "resolve_max_new_tokens", &GenerationConfig::resolve_max_new_tokens, py::arg("max_context_len"), - py::arg("num_prompt_tokens"), + py::arg("num_tokens_occupied"), "Resolve the maximum number of new tokens to generate based on constraints") .def("__repr__", [](const GenerationConfig& config) { return "