Skip to content

Commit eef7921

Browse files
authored
Widen resolve_max_new_tokens parameters to int64_t and rename for clarity (#18917)
Differential Revision: D99769848. Pull Request resolved: #18917.
1 parent c3f3d12 commit eef7921

6 files changed

Lines changed: 53 additions & 38 deletions

File tree

docs/source/llm/run-with-c-plus-plus.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,13 +183,13 @@ struct GenerationConfig {
183183
int32_t num_eos = 0; // Number of EOS tokens to add
184184

185185
// Helper method to resolve the actual max_new_tokens based on constraints
186-
int32_t resolve_max_new_tokens(int32_t max_context_len, int32_t num_prompt_tokens) const;
186+
int32_t resolve_max_new_tokens(int64_t max_context_len, int64_t num_tokens_occupied) const;
187187
};
188188
```
189189
190190
The `resolve_max_new_tokens` method handles the logic of determining how many tokens can be generated based on:
191191
- The model's maximum context length
192-
- The number of tokens in the prompt
192+
- The number of token positions already occupied in the context window
193193
- The user-specified maximum sequence length and maximum new tokens
194194
195195
### Implementation Components

extension/llm/runner/_llm_runner.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,15 @@ class GenerationConfig:
4747
...
4848

4949
def resolve_max_new_tokens(
50-
self, max_context_len: int, num_prompt_tokens: int
50+
self, max_context_len: int, num_tokens_occupied: int
5151
) -> int:
5252
"""
5353
Resolve the maximum number of new tokens to generate based on constraints.
5454
5555
Args:
5656
max_context_len: The maximum context length supported by the model
57-
num_prompt_tokens: The number of tokens in the input prompt
57+
num_tokens_occupied: The number of token positions already occupied
58+
in the context window (e.g. pos after prefill)
5859
5960
Returns:
6061
The resolved maximum number of new tokens to generate

extension/llm/runner/irunner.h

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#pragma once
1212

13+
#include <algorithm>
1314
#include <cstdint>
1415
#include <functional>
1516
#include <memory>
@@ -65,36 +66,41 @@ struct GenerationConfig {
6566
*
6667
* This method calculates the maximum number of new tokens that can be
6768
* generated considering both seq_len and max_new_tokens constraints, as well
68-
* as the model's maximum context length and the number of tokens in the
69-
* prompt.
69+
* as the model's maximum context length and how many token positions are
70+
* already occupied (e.g. by prior turns and the current prompt).
7071
*
7172
* @param max_context_len The maximum context length supported by the model
72-
* @param num_prompt_tokens The number of tokens in the input prompt
73+
* @param num_tokens_occupied The number of token positions already occupied
74+
* in the context window (e.g. pos_ after prefill)
7375
* @return The resolved maximum number of new tokens to generate
7476
*/
7577
int32_t resolve_max_new_tokens(
76-
int32_t max_context_len,
77-
int32_t num_prompt_tokens) const {
78-
int32_t result;
78+
int64_t max_context_len,
79+
int64_t num_tokens_occupied) const {
80+
int64_t result;
7981

8082
if (seq_len == -1 && max_new_tokens == -1) {
81-
// Both are -1, use max context len minus prompt tokens
82-
result = max_context_len - num_prompt_tokens;
83+
// Both are -1, use max context len minus occupied tokens
84+
result = max_context_len - num_tokens_occupied;
8385
} else if (seq_len == -1 && max_new_tokens != -1) {
8486
// Only max_new_tokens is specified
85-
result = std::min(max_new_tokens, max_context_len - num_prompt_tokens);
87+
result = std::min(
88+
static_cast<int64_t>(max_new_tokens),
89+
max_context_len - num_tokens_occupied);
8690
} else if (seq_len != -1 && max_new_tokens == -1) {
8791
// Only seq_len is specified
88-
result = std::min(seq_len, max_context_len) - num_prompt_tokens;
92+
result = std::min(static_cast<int64_t>(seq_len), max_context_len) -
93+
num_tokens_occupied;
8994
} else {
9095
// Both are specified
9196
result = std::min(
92-
std::min(seq_len, max_context_len) - num_prompt_tokens,
93-
max_new_tokens);
97+
std::min(static_cast<int64_t>(seq_len), max_context_len) -
98+
num_tokens_occupied,
99+
static_cast<int64_t>(max_new_tokens));
94100
}
95101

96102
// Ensure result is not negative
97-
return std::max(0, result);
103+
return static_cast<int32_t>(std::max(static_cast<int64_t>(0), result));
98104
}
99105
};
100106

extension/llm/runner/pybindings.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ PYBIND11_MODULE(_llm_runner, m) {
297297
"resolve_max_new_tokens",
298298
&GenerationConfig::resolve_max_new_tokens,
299299
py::arg("max_context_len"),
300-
py::arg("num_prompt_tokens"),
300+
py::arg("num_tokens_occupied"),
301301
"Resolve the maximum number of new tokens to generate based on constraints")
302302
.def("__repr__", [](const GenerationConfig& config) {
303303
return "<GenerationConfig max_new_tokens=" +

extension/llm/runner/test/test_generation_config.cpp

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,19 @@ TEST_F(GenerationConfigTest, TestResolveMaxNewTokensBothDefault) {
2020
GenerationConfig config;
2121
// Default values: seq_len = -1, max_new_tokens = -1
2222

23-
// max_context_len = 100, num_prompt_tokens = 20
24-
// Expected: max_context_len - num_prompt_tokens = 100 - 20 = 80
23+
// max_context_len = 100, num_tokens_occupied = 20
24+
// Expected: max_context_len - num_tokens_occupied = 100 - 20 = 80
2525
EXPECT_EQ(config.resolve_max_new_tokens(100, 20), 80);
2626

27-
// max_context_len = 50, num_prompt_tokens = 30
28-
// Expected: max_context_len - num_prompt_tokens = 50 - 30 = 20
27+
// max_context_len = 50, num_tokens_occupied = 30
28+
// Expected: max_context_len - num_tokens_occupied = 50 - 30 = 20
2929
EXPECT_EQ(config.resolve_max_new_tokens(50, 30), 20);
3030

31-
// Edge case: num_prompt_tokens equals max_context_len
31+
// Edge case: num_tokens_occupied equals max_context_len
3232
// Expected: 0 (no tokens left)
3333
EXPECT_EQ(config.resolve_max_new_tokens(40, 40), 0);
3434

35-
// Edge case: num_prompt_tokens exceeds max_context_len
35+
// Edge case: num_tokens_occupied exceeds max_context_len
3636
// Expected: 0 (no tokens left, and we ensure non-negative result)
3737
EXPECT_EQ(config.resolve_max_new_tokens(30, 50), 0);
3838
}
@@ -43,17 +43,17 @@ TEST_F(GenerationConfigTest, TestResolveMaxNewTokensOnlyMaxNewTokens) {
4343
config.seq_len = -1;
4444
config.max_new_tokens = 25;
4545

46-
// max_context_len = 100, num_prompt_tokens = 20
46+
// max_context_len = 100, num_tokens_occupied = 20
4747
// Available tokens: 100 - 20 = 80
4848
// Expected: min(max_new_tokens, available) = min(25, 80) = 25
4949
EXPECT_EQ(config.resolve_max_new_tokens(100, 20), 25);
5050

51-
// max_context_len = 50, num_prompt_tokens = 40
51+
// max_context_len = 50, num_tokens_occupied = 40
5252
// Available tokens: 50 - 40 = 10
5353
// Expected: min(max_new_tokens, available) = min(25, 10) = 10
5454
EXPECT_EQ(config.resolve_max_new_tokens(50, 40), 10);
5555

56-
// Edge case: num_prompt_tokens equals max_context_len
56+
// Edge case: num_tokens_occupied equals max_context_len
5757
// Available tokens: 0
5858
// Expected: 0 (no tokens left)
5959
EXPECT_EQ(config.resolve_max_new_tokens(40, 40), 0);
@@ -65,21 +65,21 @@ TEST_F(GenerationConfigTest, TestResolveMaxNewTokensOnlySeqLen) {
6565
config.seq_len = 50;
6666
config.max_new_tokens = -1;
6767

68-
// max_context_len = 100, num_prompt_tokens = 20
68+
// max_context_len = 100, num_tokens_occupied = 20
6969
// Effective seq_len: min(seq_len, max_context_len) = min(50, 100) = 50
70-
// Expected: effective_seq_len - num_prompt_tokens = 50 - 20 = 30
70+
// Expected: effective_seq_len - num_tokens_occupied = 50 - 20 = 30
7171
EXPECT_EQ(config.resolve_max_new_tokens(100, 20), 30);
7272

73-
// max_context_len = 40, num_prompt_tokens = 20
73+
// max_context_len = 40, num_tokens_occupied = 20
7474
// Effective seq_len: min(seq_len, max_context_len) = min(50, 40) = 40
75-
// Expected: effective_seq_len - num_prompt_tokens = 40 - 20 = 20
75+
// Expected: effective_seq_len - num_tokens_occupied = 40 - 20 = 20
7676
EXPECT_EQ(config.resolve_max_new_tokens(40, 20), 20);
7777

78-
// Edge case: num_prompt_tokens equals effective seq_len
78+
// Edge case: num_tokens_occupied equals effective seq_len
7979
// Expected: 0 (no tokens left)
8080
EXPECT_EQ(config.resolve_max_new_tokens(100, 50), 0);
8181

82-
// Edge case: num_prompt_tokens exceeds effective seq_len
82+
// Edge case: num_tokens_occupied exceeds effective seq_len
8383
// Expected: 0 (no tokens left, and we ensure non-negative result)
8484
EXPECT_EQ(config.resolve_max_new_tokens(100, 60), 0);
8585
}
@@ -90,19 +90,19 @@ TEST_F(GenerationConfigTest, TestResolveMaxNewTokensBothSpecified) {
9090
config.seq_len = 50;
9191
config.max_new_tokens = 25;
9292

93-
// max_context_len = 100, num_prompt_tokens = 20
93+
// max_context_len = 100, num_tokens_occupied = 20
9494
// Effective seq_len: min(seq_len, max_context_len) = min(50, 100) = 50
95-
// Available tokens: effective_seq_len - num_prompt_tokens = 50 - 20 = 30
95+
// Available tokens: effective_seq_len - num_tokens_occupied = 50 - 20 = 30
9696
// Expected: min(max_new_tokens, available) = min(25, 30) = 25
9797
EXPECT_EQ(config.resolve_max_new_tokens(100, 20), 25);
9898

99-
// max_context_len = 40, num_prompt_tokens = 20
99+
// max_context_len = 40, num_tokens_occupied = 20
100100
// Effective seq_len: min(seq_len, max_context_len) = min(50, 40) = 40
101-
// Available tokens: effective_seq_len - num_prompt_tokens = 40 - 20 = 20
101+
// Available tokens: effective_seq_len - num_tokens_occupied = 40 - 20 = 20
102102
// Expected: min(max_new_tokens, available) = min(25, 20) = 20
103103
EXPECT_EQ(config.resolve_max_new_tokens(40, 20), 20);
104104

105-
// Edge case: num_prompt_tokens equals effective seq_len
105+
// Edge case: num_tokens_occupied equals effective seq_len
106106
// Available tokens: 0
107107
// Expected: 0 (no tokens left)
108108
EXPECT_EQ(config.resolve_max_new_tokens(40, 40), 0);

extension/llm/runner/test/test_runner_pybindings.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,14 @@ def test_resolve_max_new_tokens(self):
9797
result = config.resolve_max_new_tokens(1024, 100)
9898
self.assertEqual(result, 0) # max(0, 50 - 100)
9999

100+
# Test case 6: Use keyword argument with new name
101+
config.seq_len = -1
102+
config.max_new_tokens = -1
103+
result = config.resolve_max_new_tokens(
104+
max_context_len=1024, num_tokens_occupied=100
105+
)
106+
self.assertEqual(result, 924)
107+
100108
def test_repr(self):
101109
"""Test the string representation."""
102110
config = GenerationConfig()

0 commit comments

Comments (0)