diff --git a/docs/source/llm/run-with-c-plus-plus.md b/docs/source/llm/run-with-c-plus-plus.md index 217afad847b..b6c6082c3a6 100644 --- a/docs/source/llm/run-with-c-plus-plus.md +++ b/docs/source/llm/run-with-c-plus-plus.md @@ -183,13 +183,13 @@ struct GenerationConfig { int32_t num_eos = 0; // Number of EOS tokens to add // Helper method to resolve the actual max_new_tokens based on constraints - int32_t resolve_max_new_tokens(int32_t max_context_len, int32_t num_prompt_tokens) const; + int32_t resolve_max_new_tokens(int64_t max_context_len, int64_t num_tokens_occupied) const; }; ``` The `resolve_max_new_tokens` method handles the logic of determining how many tokens can be generated based on: - The model's maximum context length -- The number of tokens in the prompt +- The number of token positions already occupied in the context window - The user-specified maximum sequence length and maximum new tokens ### Implementation Components diff --git a/extension/llm/runner/_llm_runner.pyi b/extension/llm/runner/_llm_runner.pyi index 20333578763..271cf1e1540 100644 --- a/extension/llm/runner/_llm_runner.pyi +++ b/extension/llm/runner/_llm_runner.pyi @@ -47,14 +47,15 @@ class GenerationConfig: ... def resolve_max_new_tokens( - self, max_context_len: int, num_prompt_tokens: int + self, max_context_len: int, num_tokens_occupied: int ) -> int: """ Resolve the maximum number of new tokens to generate based on constraints. Args: max_context_len: The maximum context length supported by the model - num_prompt_tokens: The number of tokens in the input prompt + num_tokens_occupied: The number of token positions already occupied + in the context window (e.g. pos after prefill) Returns: The resolved maximum number of new tokens to generate diff --git a/extension/llm/runner/irunner.h b/extension/llm/runner/irunner.h index 0fcce1f37e4..bb7dd767fea 100644 --- a/extension/llm/runner/irunner.h +++ b/extension/llm/runner/irunner.h @@ -10,6 +10,7 @@ #pragma once +#include #include #include #include @@ -65,36 +66,41 @@ struct GenerationConfig { * * This method calculates the maximum number of new tokens that can be * generated considering both seq_len and max_new_tokens constraints, as well - * as the model's maximum context length and the number of tokens in the - * prompt. + * as the model's maximum context length and how many token positions are + * already occupied (e.g. by prior turns and the current prompt). * * @param max_context_len The maximum context length supported by the model - * @param num_prompt_tokens The number of tokens in the input prompt + * @param num_tokens_occupied The number of token positions already occupied + * in the context window (e.g. pos_ after prefill) * @return The resolved maximum number of new tokens to generate */ int32_t resolve_max_new_tokens( - int32_t max_context_len, - int32_t num_prompt_tokens) const { - int32_t result; + int64_t max_context_len, + int64_t num_tokens_occupied) const { + int64_t result; if (seq_len == -1 && max_new_tokens == -1) { - // Both are -1, use max context len minus prompt tokens - result = max_context_len - num_prompt_tokens; + // Both are -1, use max context len minus occupied tokens + result = max_context_len - num_tokens_occupied; } else if (seq_len == -1 && max_new_tokens != -1) { // Only max_new_tokens is specified - result = std::min(max_new_tokens, max_context_len - num_prompt_tokens); + result = std::min( + static_cast(max_new_tokens), + max_context_len - num_tokens_occupied); } else if (seq_len != -1 && max_new_tokens == -1) { // Only seq_len is specified - result = std::min(seq_len, max_context_len) - num_prompt_tokens; + result = std::min(static_cast(seq_len), max_context_len) - + num_tokens_occupied; } else { // Both are specified result = std::min( - std::min(seq_len, max_context_len) - num_prompt_tokens, - max_new_tokens); + std::min(static_cast(seq_len), max_context_len) - + num_tokens_occupied, + static_cast(max_new_tokens)); } // Ensure result is not negative - return std::max(0, result); + return static_cast(std::max(static_cast(0), result)); } }; diff --git a/extension/llm/runner/pybindings.cpp b/extension/llm/runner/pybindings.cpp index ecd49e6341a..3188b5390c4 100644 --- a/extension/llm/runner/pybindings.cpp +++ b/extension/llm/runner/pybindings.cpp @@ -297,7 +297,7 @@ PYBIND11_MODULE(_llm_runner, m) { "resolve_max_new_tokens", &GenerationConfig::resolve_max_new_tokens, py::arg("max_context_len"), - py::arg("num_prompt_tokens"), + py::arg("num_tokens_occupied"), "Resolve the maximum number of new tokens to generate based on constraints") .def("__repr__", [](const GenerationConfig& config) { return "