
Commit 55c7eba

fix: preserve OpenAI max_new_tokens (#1318)

1 parent a31a7de

3 files changed: 129 additions & 3 deletions

lmms_eval/models/chat/openai.py

Lines changed: 3 additions & 2 deletions
@@ -22,6 +22,7 @@
     log_usage,
 )
 from lmms_eval.models.simple.openai import OpenAICompatible as OpenAICompatibleSimple
+from lmms_eval.models.simple.openai import _get_max_new_tokens
 from lmms_eval.protocol import ChatMessages
 
 VideoReader, _ = optional_import("decord", "VideoReader")
@@ -177,7 +178,7 @@ def build_payload_for_index(global_index: int) -> dict:
     chat_messages_raw = doc_to_messages(self.task_dict[task][split][doc_id])
     chat_messages: ChatMessages = ChatMessages(**{"messages": chat_messages_raw})
     request_gen_kwargs = dict(gen_kwargs)
-    max_new_tokens = min(request_gen_kwargs.get("max_new_tokens", 1024), 4096)
+    max_new_tokens = _get_max_new_tokens(request_gen_kwargs)
     temperature = request_gen_kwargs.get("temperature", 0)
 
     if self.video_fps is not None and self.video_fps > 0:
@@ -196,7 +197,7 @@ def build_payload_for_index(global_index: int) -> dict:
     payload.pop("temperature")
     payload.pop("max_tokens")
     payload["response_format"] = {"type": "text"}
-    payload["max_completion_tokens"] = 5000
+    payload["max_completion_tokens"] = max_new_tokens
 
     return payload
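Read together, the two hunks above change the chat backend in two ways: the requested token budget is no longer clamped to 4096, and the reasoning-model branch reuses that budget for max_completion_tokens instead of the literal 5000. A minimal sketch of the resulting payload handling, with illustrative values and everything unrelated to the token budget omitted:

# Sketch only: illustrative values, not the backend's full payload construction.
gen_kwargs = {"max_new_tokens": 32768, "temperature": 0.7}

max_new_tokens = gen_kwargs.get("max_new_tokens", 1024)  # what _get_max_new_tokens returns
payload = {"max_tokens": max_new_tokens, "temperature": 0.7}

# Reasoning-model branch: drop sampling fields and switch to max_completion_tokens.
payload.pop("temperature")
payload.pop("max_tokens")
payload["response_format"] = {"type": "text"}
payload["max_completion_tokens"] = max_new_tokens         # previously the literal 5000

assert payload == {"response_format": {"type": "text"}, "max_completion_tokens": 32768}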

lmms_eval/models/simple/openai.py

Lines changed: 5 additions & 1 deletion
@@ -63,6 +63,10 @@ def _normalize_openai_message_content(content) -> str:
     return str(content)
 
 
+def _get_max_new_tokens(gen_kwargs: dict) -> int:
+    return gen_kwargs.get("max_new_tokens", 1024)
+
+
 @register_model("openai")
 class OpenAICompatible(lmms):
     def __init__(
@@ -440,7 +444,7 @@ def build_payload_for_index(global_index: int):
     imgs.append(self.encode_image(visual))
 
     request_gen_kwargs = dict(gen_kwargs)
-    max_new_tokens = min(request_gen_kwargs.get("max_new_tokens", 1024), 4096)
+    max_new_tokens = _get_max_new_tokens(request_gen_kwargs)
     temperature = request_gen_kwargs.get("temperature", 0)
 
     payload = {
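Both backends now delegate to this shared helper instead of the inline min(..., 4096) clamp. Its contract is small; a sketch with illustrative values:

# _get_max_new_tokens keeps the caller's budget and only supplies the 1024 default.
assert _get_max_new_tokens({}) == 1024                        # no max_new_tokens requested
assert _get_max_new_tokens({"max_new_tokens": 8192}) == 8192  # old code clamped this to 4096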

test/models/test_openai.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@ (new file)
from __future__ import annotations

import unittest
from types import SimpleNamespace

from lmms_eval.models.chat.openai import OpenAICompatible as ChatOpenAICompatible
from lmms_eval.models.simple.openai import OpenAICompatible as SimpleOpenAICompatible


def _fake_response(content: str = "ok") -> SimpleNamespace:
    message = SimpleNamespace(content=content)
    choice = SimpleNamespace(message=message, finish_reason="stop", index=0)
    return SimpleNamespace(choices=[choice], usage=None)


class _CaptureCompletions:
    def __init__(self) -> None:
        self.payloads: list[dict] = []

    def create(self, **payload):
        self.payloads.append(payload)
        return _fake_response()


def _request(*args) -> SimpleNamespace:
    return SimpleNamespace(args=args)


def _configure_openai_model(
    model, completions: _CaptureCompletions, *, model_version: str = "gpt-4o"
) -> None:
    model.client = SimpleNamespace(chat=SimpleNamespace(completions=completions))
    model.model_version = model_version
    model.max_retries = 1
    model.retry_backoff_s = 0
    model.num_concurrent = 1
    model.adaptive_concurrency = False
    model.adaptive_config = SimpleNamespace(max_concurrency=1)
    model.prefix_aware_queue = False
    model.prefix_hash_chars = 256
    model.max_frames_num = 1
    model.video_fps = None
    model._rank = 0
    model.task_dict = {"demo": {"test": [{"id": 0}]}}


class TestOpenAICompatibleMaxTokens(unittest.TestCase):
    def test_simple_backend_preserves_requested_max_new_tokens(self):
        completions = _CaptureCompletions()
        model = SimpleOpenAICompatible.__new__(SimpleOpenAICompatible)
        _configure_openai_model(model, completions)

        model.generate_until(
            [
                _request(
                    "Describe the image",
                    {"max_new_tokens": 8192, "temperature": 0},
                    lambda _doc: None,
                    0,
                    "demo",
                    "test",
                )
            ]
        )

        self.assertEqual(completions.payloads[0]["max_tokens"], 8192)

    def test_chat_backend_preserves_requested_max_new_tokens(self):
        completions = _CaptureCompletions()
        model = ChatOpenAICompatible.__new__(ChatOpenAICompatible)
        _configure_openai_model(model, completions)

        model.generate_until(
            [
                _request(
                    "",
                    lambda _doc: [
                        {
                            "role": "user",
                            "content": [{"type": "text", "text": "Describe this"}],
                        }
                    ],
                    {"max_new_tokens": 32768, "temperature": 0},
                    0,
                    "demo",
                    "test",
                )
            ]
        )

        self.assertEqual(completions.payloads[0]["max_tokens"], 32768)

    def test_chat_reasoning_models_use_requested_completion_tokens(self):
        completions = _CaptureCompletions()
        model = ChatOpenAICompatible.__new__(ChatOpenAICompatible)
        _configure_openai_model(model, completions, model_version="gpt-5")

        model.generate_until(
            [
                _request(
                    "",
                    lambda _doc: [
                        {
                            "role": "user",
                            "content": [{"type": "text", "text": "Reason carefully"}],
                        }
                    ],
                    {"max_new_tokens": 32768, "temperature": 0.7},
                    0,
                    "demo",
                    "test",
                )
            ]
        )

        self.assertNotIn("max_tokens", completions.payloads[0])
        self.assertEqual(completions.payloads[0]["max_completion_tokens"], 32768)


if __name__ == "__main__":
    unittest.main()
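Because the module guards unittest.main(), the new tests can be run directly, for example python test/models/test_openai.py from the repository root, or through whatever test runner the project normally uses.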
