Skip to content

Commit 7697745

Browse files
b8zhongkcz358
authored and committed
[fix] batch size in openai compatible endpoint (#835)
* more * more * more * more * more * more * more * more * more * more * more * more * more * more
1 parent df477b1 commit 7697745

4 files changed

Lines changed: 275 additions & 176 deletions

File tree

lmms_eval/models/chat/openai_compatible.py

Lines changed: 106 additions & 87 deletions
Original file line number | Diff line number | Diff line change
@@ -1,28 +1,19 @@
1-
import base64
21
import json
3-
import os
42
import time
5-
from io import BytesIO
6-
from typing import List, Tuple, Union
3+
from concurrent.futures import ThreadPoolExecutor, as_completed
4+
from typing import List
75

8-
import numpy as np
9-
import requests as url_requests
10-
from accelerate import Accelerator, DistributedType
116
from tqdm import tqdm
127

13-
from lmms_eval.api.instance import Instance
14-
from lmms_eval.api.model import lmms
158
from lmms_eval.api.registry import register_model
169

1710
try:
1811
from decord import VideoReader, cpu
1912
except ImportError:
2013
pass
2114

22-
from dotenv import find_dotenv, load_dotenv
15+
from dotenv import load_dotenv
2316
from loguru import logger as eval_logger
24-
from openai import AzureOpenAI, OpenAI
25-
from PIL import Image
2617

2718
from lmms_eval.models.model_utils.gen_metrics import log_metrics
2819
from lmms_eval.models.simple.openai_compatible import (
@@ -39,89 +30,117 @@ class OpenAICompatible(OpenAICompatibleSimple):
3930

4031
def generate_until(self, requests) -> List[str]:
4132
res = []
42-
pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
33+
34+
batch_size = getattr(self, "batch_size_per_gpu", 1)
35+
batched_requests = [requests[i : i + batch_size] for i in range(0, len(requests), batch_size)]
36+
pbar = tqdm(total=len(batched_requests), disable=(self.rank != 0), desc="Model Responding")
4337

4438
e2e_latency = 0
4539
total_tokens = 0
46-
for ctx, doc_to_messages, gen_kwargs, doc_id, task, split in [reg.args for reg in requests]:
47-
if self.continual_mode is True and self.cache_mode == "resume":
48-
doc_uuid = f"{task}___{split}___{doc_id}"
49-
if doc_uuid in self.response_cache:
50-
response_text = self.response_cache[doc_uuid]
51-
if response_text:
52-
res.append(response_text)
53-
pbar.update(1)
54-
continue
55-
56-
chat_messages = doc_to_messages(self.task_dict[task][split][doc_id])
57-
chat_messages: ChatMessages = ChatMessages(**{"messages": chat_messages})
58-
59-
payload = {"messages": chat_messages.to_openai_messages()}
60-
payload["model"] = self.model_version
61-
62-
if "max_new_tokens" not in gen_kwargs:
63-
gen_kwargs["max_new_tokens"] = 1024
64-
if gen_kwargs["max_new_tokens"] > 4096:
65-
gen_kwargs["max_new_tokens"] = 4096
66-
if "temperature" not in gen_kwargs:
67-
gen_kwargs["temperature"] = 0
68-
if "top_p" not in gen_kwargs:
69-
gen_kwargs["top_p"] = None
70-
if "num_beams" not in gen_kwargs:
71-
gen_kwargs["num_beams"] = 1
72-
73-
# payload["max_completion_tokens"] = gen_kwargs["max_new_tokens"]
74-
payload["max_tokens"] = gen_kwargs["max_new_tokens"]
75-
payload["temperature"] = gen_kwargs["temperature"]
76-
77-
if "o1" in self.model_version or "o3" in self.model_version or "o4" in self.model_version:
78-
# del payload["max_output_tokens"]
79-
del payload["temperature"]
80-
payload.pop("max_tokens")
81-
payload["reasoning_effort"] = "medium"
82-
payload["response_format"] = {"type": "text"}
83-
payload["max_completion_tokens"] = gen_kwargs["max_new_tokens"]
84-
85-
for attempt in range(self.max_retries):
86-
try:
87-
start_time = time.time()
88-
response = self.client.chat.completions.create(**payload)
89-
end_time = time.time()
90-
91-
response_text = response.choices[0].message.content
92-
93-
# Calculate timing metrics
94-
e2e_latency += end_time - start_time
95-
96-
# Get token counts from response if available
97-
if hasattr(response, "usage"):
98-
total_tokens += response.usage.completion_tokens
99-
else:
100-
# Approximate token count if not provided
101-
total_tokens += len(response_text.split())
102-
103-
break # If successful, break out of the loop
104-
105-
except Exception as e:
106-
error_msg = str(e)
107-
eval_logger.info(f"Attempt {attempt + 1}/{self.max_retries} failed with error: {error_msg}")
108-
109-
# On last attempt, log error and set empty response
110-
if attempt == self.max_retries - 1:
111-
eval_logger.error(f"All {self.max_retries} attempts failed. Last error: {error_msg}")
112-
response_text = ""
113-
else:
114-
time.sleep(self.timeout)
115-
116-
res.append(response_text)
117-
pbar.update(1)
11840

119-
if self.continual_mode is True: # Cache the response
41+
for batch_requests in batched_requests:
42+
batch_payloads = []
43+
batch_doc_uuids = []
44+
batch_responses = []
45+
46+
for req in batch_requests:
47+
ctx, doc_to_messages, gen_kwargs, doc_id, task, split = req.args
12048
doc_uuid = f"{task}___{split}___{doc_id}"
121-
self.response_cache[doc_uuid] = response_text
49+
batch_doc_uuids.append(doc_uuid)
50+
51+
if self.continual_mode is True and self.cache_mode == "resume":
52+
if doc_uuid in self.response_cache:
53+
response_text = self.response_cache[doc_uuid]
54+
if response_text:
55+
batch_responses.append(response_text)
56+
continue
57+
58+
chat_messages_raw = doc_to_messages(self.task_dict[task][split][doc_id])
59+
chat_messages: ChatMessages = ChatMessages(**{"messages": chat_messages_raw})
60+
61+
payload = {"messages": chat_messages.to_openai_messages()}
62+
payload["model"] = self.model_version
63+
64+
if "max_new_tokens" not in gen_kwargs:
65+
gen_kwargs["max_new_tokens"] = 1024
66+
if gen_kwargs["max_new_tokens"] > 4096:
67+
gen_kwargs["max_new_tokens"] = 4096
68+
if "temperature" not in gen_kwargs:
69+
gen_kwargs["temperature"] = 0
70+
if "top_p" not in gen_kwargs:
71+
gen_kwargs["top_p"] = None
72+
if "num_beams" not in gen_kwargs:
73+
gen_kwargs["num_beams"] = 1
74+
75+
payload["max_tokens"] = gen_kwargs["max_new_tokens"]
76+
payload["temperature"] = gen_kwargs["temperature"]
77+
78+
if "o1" in self.model_version or "o3" in self.model_version or "o4" in self.model_version:
79+
del payload["temperature"]
80+
payload.pop("max_tokens")
81+
payload["reasoning_effort"] = "medium"
82+
payload["response_format"] = {"type": "text"}
83+
payload["max_completion_tokens"] = gen_kwargs["max_new_tokens"]
84+
85+
batch_payloads.append(payload)
86+
batch_responses.append(None)
87+
88+
def process_single_request(payload, i):
89+
if batch_responses[i] is not None:
90+
return batch_responses[i], i, 0, 0
91+
92+
for attempt in range(self.max_retries):
93+
try:
94+
start_time = time.time()
95+
response = self.client.chat.completions.create(**payload)
96+
end_time = time.time()
97+
98+
response_text = response.choices[0].message.content
99+
latency = end_time - start_time
100+
101+
tokens = 0
102+
if hasattr(response, "usage"):
103+
tokens = response.usage.completion_tokens
104+
else:
105+
tokens = len(response_text.split())
106+
107+
return response_text, i, latency, tokens
108+
109+
except Exception as e:
110+
error_msg = str(e)
111+
eval_logger.info(f"Attempt {attempt + 1}/{self.max_retries} failed with error: {error_msg}")
112+
113+
if attempt == self.max_retries - 1:
114+
eval_logger.error(f"All {self.max_retries} attempts failed. Last error: {error_msg}")
115+
return "", i, 0, 0
116+
else:
117+
time.sleep(self.timeout)
118+
119+
return "", i, 0, 0
120+
121+
tasks_to_run = [(payload, i) for i, payload in enumerate(batch_payloads) if batch_responses[i] is None]
122+
123+
if tasks_to_run:
124+
max_workers = min(len(tasks_to_run), 32)
125+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
126+
future_to_index = {executor.submit(process_single_request, payload, i): i for payload, i in tasks_to_run}
127+
128+
for future in as_completed(future_to_index):
129+
response_text, i, latency, tokens = future.result()
130+
batch_responses[i] = response_text
131+
e2e_latency += latency
132+
total_tokens += tokens
133+
134+
if self.continual_mode is True:
135+
for doc_uuid, response_text in zip(batch_doc_uuids, batch_responses):
136+
if response_text is not None:
137+
self.response_cache[doc_uuid] = response_text
122138
with open(self.response_persistent_file, "w") as f:
123139
json.dump(self.response_cache, f)
124140

141+
res.extend([r for r in batch_responses if r is not None])
142+
pbar.update(1)
143+
125144
# Calculate average speed
126145
avg_speed = total_tokens / e2e_latency if e2e_latency > 0 else 0
127146
# Log metrics

0 commit comments

Comments
 (0)