2121
2222import click
2323import numpy as np
24- from tqdm import tqdm
2524
2625import tensorrt_llm .profiler as profiler
2726from tensorrt_llm .inputs import prompt_inputs
4443from ..llmapi import RequestOutput
4544from ..logger import logger
4645from ..sampling_params import SamplingParams
47- from .interface import (Evaluator , dump_inference_results ,
48- get_chat_template_kwargs )
46+ from .interface import (RESULT_WAIT_TIMEOUT_SECS , Evaluator ,
47+ dump_inference_results , get_chat_template_kwargs )
48+ from .progress import tqdm_with_time_prefix
4949
5050# NOTE: lm_eval uses "<image>" as the default image placeholder
5151# https://github.com/EleutherAI/lm-evaluation-harness/blob/7f04db12d2f8e7a99a0830d99eb78130e1ba2122/lm_eval/models/hf_vlms.py#L25
@@ -162,9 +162,9 @@ def _get_sampling_params(self, gen_kwargs: dict) -> SamplingParams:
162162 def generate_until (self , requests , disable_tqdm : bool = False ) -> List [str ]:
163163 profiler .start ("trtllm exec" )
164164 results = []
165- for request in tqdm (requests ,
166- desc = "Submitting requests" ,
167- disable = disable_tqdm ):
165+ for request in tqdm_with_time_prefix (requests ,
166+ desc = "Submitting requests" ,
167+ disable = disable_tqdm ):
168168 prompt , gen_kwargs = request .args
169169 sampling_params = self ._get_sampling_params (gen_kwargs )
170170 output = self .llm .generate_async (prompt ,
@@ -173,10 +173,10 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
173173 results .append (output )
174174
175175 outputs = []
176- for output in tqdm (results ,
177- desc = "Fetching responses" ,
178- disable = disable_tqdm ):
179- outputs .append (output .result ())
176+ for output in tqdm_with_time_prefix (results ,
177+ desc = "Fetching responses" ,
178+ disable = disable_tqdm ):
179+ outputs .append (output .result (timeout = RESULT_WAIT_TIMEOUT_SECS ))
180180
181181 if self .output_dir :
182182 dump_inference_results (self .output_dir , outputs ,
@@ -405,9 +405,9 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
405405 """
406406 profiler .start ("trtllm exec" )
407407 results = []
408- for request in tqdm (requests ,
409- desc = "Submitting requests" ,
410- disable = disable_tqdm ):
408+ for request in tqdm_with_time_prefix (requests ,
409+ desc = "Submitting requests" ,
410+ disable = disable_tqdm ):
411411
412412 # NOTE: For now, only this part is different from the original generate_until
413413 prompt , gen_kwargs , media_data = request .args
@@ -431,10 +431,10 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
431431 results .append (output )
432432
433433 outputs = []
434- for output in tqdm (results ,
435- desc = "Fetching responses" ,
436- disable = disable_tqdm ):
437- outputs .append (output .result ())
434+ for output in tqdm_with_time_prefix (results ,
435+ desc = "Fetching responses" ,
436+ disable = disable_tqdm ):
437+ outputs .append (output .result (timeout = RESULT_WAIT_TIMEOUT_SECS ))
438438
439439 if self .output_dir :
440440 dump_inference_results (self .output_dir , outputs ,
0 commit comments