Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions lmms_eval/models/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
log_usage,
)
from lmms_eval.models.simple.openai import OpenAICompatible as OpenAICompatibleSimple
from lmms_eval.models.simple.openai import _get_max_new_tokens
from lmms_eval.protocol import ChatMessages

VideoReader, _ = optional_import("decord", "VideoReader")
Expand Down Expand Up @@ -177,7 +178,7 @@ def build_payload_for_index(global_index: int) -> dict:
chat_messages_raw = doc_to_messages(self.task_dict[task][split][doc_id])
chat_messages: ChatMessages = ChatMessages(**{"messages": chat_messages_raw})
request_gen_kwargs = dict(gen_kwargs)
max_new_tokens = min(request_gen_kwargs.get("max_new_tokens", 1024), 4096)
max_new_tokens = _get_max_new_tokens(request_gen_kwargs)
temperature = request_gen_kwargs.get("temperature", 0)

if self.video_fps is not None and self.video_fps > 0:
Expand All @@ -196,7 +197,7 @@ def build_payload_for_index(global_index: int) -> dict:
payload.pop("temperature")
payload.pop("max_tokens")
payload["response_format"] = {"type": "text"}
payload["max_completion_tokens"] = 5000
payload["max_completion_tokens"] = max_new_tokens

return payload

Expand Down
6 changes: 5 additions & 1 deletion lmms_eval/models/simple/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def _normalize_openai_message_content(content) -> str:
return str(content)


def _get_max_new_tokens(gen_kwargs: dict) -> int:
return gen_kwargs.get("max_new_tokens", 1024)


@register_model("openai")
class OpenAICompatible(lmms):
def __init__(
Expand Down Expand Up @@ -440,7 +444,7 @@ def build_payload_for_index(global_index: int):
imgs.append(self.encode_image(visual))

request_gen_kwargs = dict(gen_kwargs)
max_new_tokens = min(request_gen_kwargs.get("max_new_tokens", 1024), 4096)
max_new_tokens = _get_max_new_tokens(request_gen_kwargs)
temperature = request_gen_kwargs.get("temperature", 0)

payload = {
Expand Down
121 changes: 121 additions & 0 deletions test/models/test_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from __future__ import annotations

import unittest
from types import SimpleNamespace

from lmms_eval.models.chat.openai import OpenAICompatible as ChatOpenAICompatible
from lmms_eval.models.simple.openai import OpenAICompatible as SimpleOpenAICompatible


def _fake_response(content: str = "ok") -> SimpleNamespace:
message = SimpleNamespace(content=content)
choice = SimpleNamespace(message=message, finish_reason="stop", index=0)
return SimpleNamespace(choices=[choice], usage=None)


class _CaptureCompletions:
    """Fake ``client.chat.completions`` endpoint that records every payload.

    Each keyword payload passed to :meth:`create` is appended to
    ``payloads`` so tests can inspect exactly what was sent to the API.
    """

    def __init__(self) -> None:
        # One dict per create() call, in call order.
        self.payloads: list[dict] = []

    def create(self, **payload):
        """Record *payload* and answer with a canned successful response."""
        self.payloads.append(payload)
        return _fake_response()


def _request(*args) -> SimpleNamespace:
return SimpleNamespace(args=args)


def _configure_openai_model(
model, completions: _CaptureCompletions, *, model_version: str = "gpt-4o"
) -> None:
model.client = SimpleNamespace(chat=SimpleNamespace(completions=completions))
model.model_version = model_version
model.max_retries = 1
model.retry_backoff_s = 0
model.num_concurrent = 1
model.adaptive_concurrency = False
model.adaptive_config = SimpleNamespace(max_concurrency=1)
model.prefix_aware_queue = False
model.prefix_hash_chars = 256
model.max_frames_num = 1
model.video_fps = None
model._rank = 0
model.task_dict = {"demo": {"test": [{"id": 0}]}}


class TestOpenAICompatibleMaxTokens(unittest.TestCase):
    """Regression tests: the user-requested ``max_new_tokens`` must reach the API.

    Guards both OpenAI-compatible backends against re-introducing the old
    clamping behaviour (``min(..., 4096)``, and a hard-coded 5000 completion
    budget for reasoning models) visible in the accompanying diff.
    """

    def test_simple_backend_preserves_requested_max_new_tokens(self):
        """Simple backend forwards 8192 unclamped as ``max_tokens``."""
        completions = _CaptureCompletions()
        # __new__ bypasses __init__ so no real API client/config is built.
        model = SimpleOpenAICompatible.__new__(SimpleOpenAICompatible)
        _configure_openai_model(model, completions)

        model.generate_until(
            [
                # Positional layout presumably mirrors the project's request
                # contract: (context, gen_kwargs, doc_to_visual, doc_id,
                # task, split) — TODO confirm against lmms_eval's Instance.
                _request(
                    "Describe the image",
                    {"max_new_tokens": 8192, "temperature": 0},
                    lambda _doc: None,
                    0,
                    "demo",
                    "test",
                )
            ]
        )

        # The captured payload must carry the unclamped budget.
        self.assertEqual(completions.payloads[0]["max_tokens"], 8192)

    def test_chat_backend_preserves_requested_max_new_tokens(self):
        """Chat backend forwards 32768 unclamped as ``max_tokens``."""
        completions = _CaptureCompletions()
        model = ChatOpenAICompatible.__new__(ChatOpenAICompatible)
        _configure_openai_model(model, completions)

        model.generate_until(
            [
                # Chat requests pass a doc_to_messages callable instead of a
                # visual callable; second positional slot builds the messages.
                _request(
                    "",
                    lambda _doc: [
                        {
                            "role": "user",
                            "content": [{"type": "text", "text": "Describe this"}],
                        }
                    ],
                    {"max_new_tokens": 32768, "temperature": 0},
                    0,
                    "demo",
                    "test",
                )
            ]
        )

        self.assertEqual(completions.payloads[0]["max_tokens"], 32768)

    def test_chat_reasoning_models_use_requested_completion_tokens(self):
        """Reasoning models ("gpt-5") swap ``max_tokens`` for
        ``max_completion_tokens`` and still honour the requested budget
        (previously hard-coded to 5000)."""
        completions = _CaptureCompletions()
        model = ChatOpenAICompatible.__new__(ChatOpenAICompatible)
        _configure_openai_model(model, completions, model_version="gpt-5")

        model.generate_until(
            [
                _request(
                    "",
                    lambda _doc: [
                        {
                            "role": "user",
                            "content": [{"type": "text", "text": "Reason carefully"}],
                        }
                    ],
                    {"max_new_tokens": 32768, "temperature": 0.7},
                    0,
                    "demo",
                    "test",
                )
            ]
        )

        # Reasoning models must not send max_tokens at all.
        self.assertNotIn("max_tokens", completions.payloads[0])
        self.assertEqual(completions.payloads[0]["max_completion_tokens"], 32768)


if __name__ == "__main__":
unittest.main()
Loading