1- import base64
21import json
3- import os
42import time
5- from io import BytesIO
6- from typing import List , Tuple , Union
3+ from concurrent . futures import ThreadPoolExecutor , as_completed
4+ from typing import List
75
8- import numpy as np
9- import requests as url_requests
10- from accelerate import Accelerator , DistributedType
116from tqdm import tqdm
127
13- from lmms_eval .api .instance import Instance
14- from lmms_eval .api .model import lmms
158from lmms_eval .api .registry import register_model
169
1710try :
1811 from decord import VideoReader , cpu
1912except ImportError :
2013 pass
2114
22- from dotenv import find_dotenv , load_dotenv
15+ from dotenv import load_dotenv
2316from loguru import logger as eval_logger
24- from openai import AzureOpenAI , OpenAI
25- from PIL import Image
2617
2718from lmms_eval .models .model_utils .gen_metrics import log_metrics
2819from lmms_eval .models .simple .openai_compatible import (
@@ -39,89 +30,117 @@ class OpenAICompatible(OpenAICompatibleSimple):
3930
4031 def generate_until (self , requests ) -> List [str ]:
4132 res = []
42- pbar = tqdm (total = len (requests ), disable = (self .rank != 0 ), desc = "Model Responding" )
33+
34+ batch_size = getattr (self , "batch_size_per_gpu" , 1 )
35+ batched_requests = [requests [i : i + batch_size ] for i in range (0 , len (requests ), batch_size )]
36+ pbar = tqdm (total = len (batched_requests ), disable = (self .rank != 0 ), desc = "Model Responding" )
4337
4438 e2e_latency = 0
4539 total_tokens = 0
46- for ctx , doc_to_messages , gen_kwargs , doc_id , task , split in [reg .args for reg in requests ]:
47- if self .continual_mode is True and self .cache_mode == "resume" :
48- doc_uuid = f"{ task } ___{ split } ___{ doc_id } "
49- if doc_uuid in self .response_cache :
50- response_text = self .response_cache [doc_uuid ]
51- if response_text :
52- res .append (response_text )
53- pbar .update (1 )
54- continue
55-
56- chat_messages = doc_to_messages (self .task_dict [task ][split ][doc_id ])
57- chat_messages : ChatMessages = ChatMessages (** {"messages" : chat_messages })
58-
59- payload = {"messages" : chat_messages .to_openai_messages ()}
60- payload ["model" ] = self .model_version
61-
62- if "max_new_tokens" not in gen_kwargs :
63- gen_kwargs ["max_new_tokens" ] = 1024
64- if gen_kwargs ["max_new_tokens" ] > 4096 :
65- gen_kwargs ["max_new_tokens" ] = 4096
66- if "temperature" not in gen_kwargs :
67- gen_kwargs ["temperature" ] = 0
68- if "top_p" not in gen_kwargs :
69- gen_kwargs ["top_p" ] = None
70- if "num_beams" not in gen_kwargs :
71- gen_kwargs ["num_beams" ] = 1
72-
73- # payload["max_completion_tokens"] = gen_kwargs["max_new_tokens"]
74- payload ["max_tokens" ] = gen_kwargs ["max_new_tokens" ]
75- payload ["temperature" ] = gen_kwargs ["temperature" ]
76-
77- if "o1" in self .model_version or "o3" in self .model_version or "o4" in self .model_version :
78- # del payload["max_output_tokens"]
79- del payload ["temperature" ]
80- payload .pop ("max_tokens" )
81- payload ["reasoning_effort" ] = "medium"
82- payload ["response_format" ] = {"type" : "text" }
83- payload ["max_completion_tokens" ] = gen_kwargs ["max_new_tokens" ]
84-
85- for attempt in range (self .max_retries ):
86- try :
87- start_time = time .time ()
88- response = self .client .chat .completions .create (** payload )
89- end_time = time .time ()
90-
91- response_text = response .choices [0 ].message .content
92-
93- # Calculate timing metrics
94- e2e_latency += end_time - start_time
95-
96- # Get token counts from response if available
97- if hasattr (response , "usage" ):
98- total_tokens += response .usage .completion_tokens
99- else :
100- # Approximate token count if not provided
101- total_tokens += len (response_text .split ())
102-
103- break # If successful, break out of the loop
104-
105- except Exception as e :
106- error_msg = str (e )
107- eval_logger .info (f"Attempt { attempt + 1 } /{ self .max_retries } failed with error: { error_msg } " )
108-
109- # On last attempt, log error and set empty response
110- if attempt == self .max_retries - 1 :
111- eval_logger .error (f"All { self .max_retries } attempts failed. Last error: { error_msg } " )
112- response_text = ""
113- else :
114- time .sleep (self .timeout )
115-
116- res .append (response_text )
117- pbar .update (1 )
11840
119- if self .continual_mode is True : # Cache the response
41+ for batch_requests in batched_requests :
42+ batch_payloads = []
43+ batch_doc_uuids = []
44+ batch_responses = []
45+
46+ for req in batch_requests :
47+ ctx , doc_to_messages , gen_kwargs , doc_id , task , split = req .args
12048 doc_uuid = f"{ task } ___{ split } ___{ doc_id } "
121- self .response_cache [doc_uuid ] = response_text
49+ batch_doc_uuids .append (doc_uuid )
50+
51+ if self .continual_mode is True and self .cache_mode == "resume" :
52+ if doc_uuid in self .response_cache :
53+ response_text = self .response_cache [doc_uuid ]
54+ if response_text :
55+ batch_responses .append (response_text )
56+ continue
57+
58+ chat_messages_raw = doc_to_messages (self .task_dict [task ][split ][doc_id ])
59+ chat_messages : ChatMessages = ChatMessages (** {"messages" : chat_messages_raw })
60+
61+ payload = {"messages" : chat_messages .to_openai_messages ()}
62+ payload ["model" ] = self .model_version
63+
64+ if "max_new_tokens" not in gen_kwargs :
65+ gen_kwargs ["max_new_tokens" ] = 1024
66+ if gen_kwargs ["max_new_tokens" ] > 4096 :
67+ gen_kwargs ["max_new_tokens" ] = 4096
68+ if "temperature" not in gen_kwargs :
69+ gen_kwargs ["temperature" ] = 0
70+ if "top_p" not in gen_kwargs :
71+ gen_kwargs ["top_p" ] = None
72+ if "num_beams" not in gen_kwargs :
73+ gen_kwargs ["num_beams" ] = 1
74+
75+ payload ["max_tokens" ] = gen_kwargs ["max_new_tokens" ]
76+ payload ["temperature" ] = gen_kwargs ["temperature" ]
77+
78+ if "o1" in self .model_version or "o3" in self .model_version or "o4" in self .model_version :
79+ del payload ["temperature" ]
80+ payload .pop ("max_tokens" )
81+ payload ["reasoning_effort" ] = "medium"
82+ payload ["response_format" ] = {"type" : "text" }
83+ payload ["max_completion_tokens" ] = gen_kwargs ["max_new_tokens" ]
84+
85+ batch_payloads .append (payload )
86+ batch_responses .append (None )
87+
88+ def process_single_request (payload , i ):
89+ if batch_responses [i ] is not None :
90+ return batch_responses [i ], i , 0 , 0
91+
92+ for attempt in range (self .max_retries ):
93+ try :
94+ start_time = time .time ()
95+ response = self .client .chat .completions .create (** payload )
96+ end_time = time .time ()
97+
98+ response_text = response .choices [0 ].message .content
99+ latency = end_time - start_time
100+
101+ tokens = 0
102+ if hasattr (response , "usage" ):
103+ tokens = response .usage .completion_tokens
104+ else :
105+ tokens = len (response_text .split ())
106+
107+ return response_text , i , latency , tokens
108+
109+ except Exception as e :
110+ error_msg = str (e )
111+ eval_logger .info (f"Attempt { attempt + 1 } /{ self .max_retries } failed with error: { error_msg } " )
112+
113+ if attempt == self .max_retries - 1 :
114+ eval_logger .error (f"All { self .max_retries } attempts failed. Last error: { error_msg } " )
115+ return "" , i , 0 , 0
116+ else :
117+ time .sleep (self .timeout )
118+
119+ return "" , i , 0 , 0
120+
121+ tasks_to_run = [(payload , i ) for i , payload in enumerate (batch_payloads ) if batch_responses [i ] is None ]
122+
123+ if tasks_to_run :
124+ max_workers = min (len (tasks_to_run ), 32 )
125+ with ThreadPoolExecutor (max_workers = max_workers ) as executor :
126+ future_to_index = {executor .submit (process_single_request , payload , i ): i for payload , i in tasks_to_run }
127+
128+ for future in as_completed (future_to_index ):
129+ response_text , i , latency , tokens = future .result ()
130+ batch_responses [i ] = response_text
131+ e2e_latency += latency
132+ total_tokens += tokens
133+
134+ if self .continual_mode is True :
135+ for doc_uuid , response_text in zip (batch_doc_uuids , batch_responses ):
136+ if response_text is not None :
137+ self .response_cache [doc_uuid ] = response_text
122138 with open (self .response_persistent_file , "w" ) as f :
123139 json .dump (self .response_cache , f )
124140
141+ res .extend ([r for r in batch_responses if r is not None ])
142+ pbar .update (1 )
143+
125144 # Calculate average speed
126145 avg_speed = total_tokens / e2e_latency if e2e_latency > 0 else 0
127146 # Log metrics
0 commit comments