Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/llm/run-with-c-plus-plus.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,13 @@ struct GenerationConfig {
int32_t num_eos = 0; // Number of EOS tokens to add

// Helper method to resolve the actual max_new_tokens based on constraints
int32_t resolve_max_new_tokens(int32_t max_context_len, int32_t num_prompt_tokens) const;
int32_t resolve_max_new_tokens(int64_t max_context_len, int64_t num_tokens_occupied) const;
};
```

The `resolve_max_new_tokens` method handles the logic of determining how many tokens can be generated based on:
- The model's maximum context length
- The number of tokens in the prompt
- The number of token positions already occupied in the context window
- The user-specified maximum sequence length and maximum new tokens

### Implementation Components
Expand Down
5 changes: 3 additions & 2 deletions extension/llm/runner/_llm_runner.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,15 @@ class GenerationConfig:
...

def resolve_max_new_tokens(
self, max_context_len: int, num_prompt_tokens: int
self, max_context_len: int, num_tokens_occupied: int
) -> int:
"""
Resolve the maximum number of new tokens to generate based on constraints.

Args:
max_context_len: The maximum context length supported by the model
num_prompt_tokens: The number of tokens in the input prompt
num_tokens_occupied: The number of token positions already occupied
in the context window (e.g. pos after prefill)

Returns:
The resolved maximum number of new tokens to generate
Expand Down
32 changes: 19 additions & 13 deletions extension/llm/runner/irunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#pragma once

#include <algorithm>
#include <cstdint>
#include <functional>
#include <memory>
Expand Down Expand Up @@ -65,36 +66,41 @@ struct GenerationConfig {
*
* This method calculates the maximum number of new tokens that can be
* generated considering both seq_len and max_new_tokens constraints, as well
* as the model's maximum context length and the number of tokens in the
* prompt.
* as the model's maximum context length and how many token positions are
* already occupied (e.g. by prior turns and the current prompt).
*
* @param max_context_len The maximum context length supported by the model
* @param num_prompt_tokens The number of tokens in the input prompt
* @param num_tokens_occupied The number of token positions already occupied
* in the context window (e.g. pos_ after prefill)
* @return The resolved maximum number of new tokens to generate
*/
int32_t resolve_max_new_tokens(
int32_t max_context_len,
int32_t num_prompt_tokens) const {
int32_t result;
int64_t max_context_len,
int64_t num_tokens_occupied) const {
int64_t result;

if (seq_len == -1 && max_new_tokens == -1) {
// Both are -1, use max context len minus prompt tokens
result = max_context_len - num_prompt_tokens;
// Both are -1, use max context len minus occupied tokens
result = max_context_len - num_tokens_occupied;
} else if (seq_len == -1 && max_new_tokens != -1) {
// Only max_new_tokens is specified
result = std::min(max_new_tokens, max_context_len - num_prompt_tokens);
result = std::min(
static_cast<int64_t>(max_new_tokens),
max_context_len - num_tokens_occupied);
} else if (seq_len != -1 && max_new_tokens == -1) {
// Only seq_len is specified
result = std::min(seq_len, max_context_len) - num_prompt_tokens;
result = std::min(static_cast<int64_t>(seq_len), max_context_len) -
num_tokens_occupied;
} else {
// Both are specified
result = std::min(
std::min(seq_len, max_context_len) - num_prompt_tokens,
max_new_tokens);
std::min(static_cast<int64_t>(seq_len), max_context_len) -
num_tokens_occupied,
static_cast<int64_t>(max_new_tokens));
Comment on lines +87 to +99
}

// Ensure result is not negative
return std::max(0, result);
return static_cast<int32_t>(std::max(static_cast<int64_t>(0), result));
}
Comment on lines 102 to 104
};

Expand Down
2 changes: 1 addition & 1 deletion extension/llm/runner/pybindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ PYBIND11_MODULE(_llm_runner, m) {
"resolve_max_new_tokens",
&GenerationConfig::resolve_max_new_tokens,
py::arg("max_context_len"),
py::arg("num_prompt_tokens"),
py::arg("num_tokens_occupied"),
"Resolve the maximum number of new tokens to generate based on constraints")
.def("__repr__", [](const GenerationConfig& config) {
return "<GenerationConfig max_new_tokens=" +
Expand Down
40 changes: 20 additions & 20 deletions extension/llm/runner/test/test_generation_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@ TEST_F(GenerationConfigTest, TestResolveMaxNewTokensBothDefault) {
GenerationConfig config;
// Default values: seq_len = -1, max_new_tokens = -1

// max_context_len = 100, num_prompt_tokens = 20
// Expected: max_context_len - num_prompt_tokens = 100 - 20 = 80
// max_context_len = 100, num_tokens_occupied = 20
// Expected: max_context_len - num_tokens_occupied = 100 - 20 = 80
EXPECT_EQ(config.resolve_max_new_tokens(100, 20), 80);

// max_context_len = 50, num_prompt_tokens = 30
// Expected: max_context_len - num_prompt_tokens = 50 - 30 = 20
// max_context_len = 50, num_tokens_occupied = 30
// Expected: max_context_len - num_tokens_occupied = 50 - 30 = 20
EXPECT_EQ(config.resolve_max_new_tokens(50, 30), 20);

// Edge case: num_prompt_tokens equals max_context_len
// Edge case: num_tokens_occupied equals max_context_len
// Expected: 0 (no tokens left)
EXPECT_EQ(config.resolve_max_new_tokens(40, 40), 0);

// Edge case: num_prompt_tokens exceeds max_context_len
// Edge case: num_tokens_occupied exceeds max_context_len
// Expected: 0 (no tokens left, and we ensure non-negative result)
EXPECT_EQ(config.resolve_max_new_tokens(30, 50), 0);
}
Expand All @@ -43,17 +43,17 @@ TEST_F(GenerationConfigTest, TestResolveMaxNewTokensOnlyMaxNewTokens) {
config.seq_len = -1;
config.max_new_tokens = 25;

// max_context_len = 100, num_prompt_tokens = 20
// max_context_len = 100, num_tokens_occupied = 20
// Available tokens: 100 - 20 = 80
// Expected: min(max_new_tokens, available) = min(25, 80) = 25
EXPECT_EQ(config.resolve_max_new_tokens(100, 20), 25);

// max_context_len = 50, num_prompt_tokens = 40
// max_context_len = 50, num_tokens_occupied = 40
// Available tokens: 50 - 40 = 10
// Expected: min(max_new_tokens, available) = min(25, 10) = 10
EXPECT_EQ(config.resolve_max_new_tokens(50, 40), 10);

// Edge case: num_prompt_tokens equals max_context_len
// Edge case: num_tokens_occupied equals max_context_len
// Available tokens: 0
// Expected: 0 (no tokens left)
EXPECT_EQ(config.resolve_max_new_tokens(40, 40), 0);
Expand All @@ -65,21 +65,21 @@ TEST_F(GenerationConfigTest, TestResolveMaxNewTokensOnlySeqLen) {
config.seq_len = 50;
config.max_new_tokens = -1;

// max_context_len = 100, num_prompt_tokens = 20
// max_context_len = 100, num_tokens_occupied = 20
// Effective seq_len: min(seq_len, max_context_len) = min(50, 100) = 50
// Expected: effective_seq_len - num_prompt_tokens = 50 - 20 = 30
// Expected: effective_seq_len - num_tokens_occupied = 50 - 20 = 30
EXPECT_EQ(config.resolve_max_new_tokens(100, 20), 30);

// max_context_len = 40, num_prompt_tokens = 20
// max_context_len = 40, num_tokens_occupied = 20
// Effective seq_len: min(seq_len, max_context_len) = min(50, 40) = 40
// Expected: effective_seq_len - num_prompt_tokens = 40 - 20 = 20
// Expected: effective_seq_len - num_tokens_occupied = 40 - 20 = 20
EXPECT_EQ(config.resolve_max_new_tokens(40, 20), 20);

// Edge case: num_prompt_tokens equals effective seq_len
// Edge case: num_tokens_occupied equals effective seq_len
// Expected: 0 (no tokens left)
EXPECT_EQ(config.resolve_max_new_tokens(100, 50), 0);

// Edge case: num_prompt_tokens exceeds effective seq_len
// Edge case: num_tokens_occupied exceeds effective seq_len
// Expected: 0 (no tokens left, and we ensure non-negative result)
EXPECT_EQ(config.resolve_max_new_tokens(100, 60), 0);
}
Expand All @@ -90,19 +90,19 @@ TEST_F(GenerationConfigTest, TestResolveMaxNewTokensBothSpecified) {
config.seq_len = 50;
config.max_new_tokens = 25;

// max_context_len = 100, num_prompt_tokens = 20
// max_context_len = 100, num_tokens_occupied = 20
// Effective seq_len: min(seq_len, max_context_len) = min(50, 100) = 50
// Available tokens: effective_seq_len - num_prompt_tokens = 50 - 20 = 30
// Available tokens: effective_seq_len - num_tokens_occupied = 50 - 20 = 30
// Expected: min(max_new_tokens, available) = min(25, 30) = 25
EXPECT_EQ(config.resolve_max_new_tokens(100, 20), 25);

// max_context_len = 40, num_prompt_tokens = 20
// max_context_len = 40, num_tokens_occupied = 20
// Effective seq_len: min(seq_len, max_context_len) = min(50, 40) = 40
// Available tokens: effective_seq_len - num_prompt_tokens = 40 - 20 = 20
// Available tokens: effective_seq_len - num_tokens_occupied = 40 - 20 = 20
// Expected: min(max_new_tokens, available) = min(25, 20) = 20
EXPECT_EQ(config.resolve_max_new_tokens(40, 20), 20);

// Edge case: num_prompt_tokens equals effective seq_len
// Edge case: num_tokens_occupied equals effective seq_len
// Available tokens: 0
// Expected: 0 (no tokens left)
EXPECT_EQ(config.resolve_max_new_tokens(40, 40), 0);
Expand Down
8 changes: 8 additions & 0 deletions extension/llm/runner/test/test_runner_pybindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ def test_resolve_max_new_tokens(self):
result = config.resolve_max_new_tokens(1024, 100)
self.assertEqual(result, 0) # max(0, 50 - 100)

# Test case 6: Use keyword argument with new name
config.seq_len = -1
config.max_new_tokens = -1
result = config.resolve_max_new_tokens(
max_context_len=1024, num_tokens_occupied=100
)
self.assertEqual(result, 924)

def test_repr(self):
"""Test the string representation."""
config = GenerationConfig()
Expand Down
Loading