Skip to content

Commit acd5638

Browse files
authored
[Benchmark] support input_ids for benchmark dataset (#7993)
1 parent 555a6c5 commit acd5638

2 files changed

Lines changed: 23 additions & 10 deletions

File tree

benchmarks/backend_request_func.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -356,14 +356,16 @@ async def async_request_eb_openai_chat_completions(
356356
if request_func_input.response_format:
357357
payload["response_format"] = request_func_input.response_format
358358

359-
# 随机输入开关
359+
# Random-length input/output knob.
360360
if request_func_input.random_flag:
361361
payload["max_tokens"] = request_func_input.output_len
362362
payload["min_tokens"] = request_func_input.output_len
363-
# 随机token_ids场景
364-
if isinstance(request_func_input.prompt, list):
365-
request_func_input.prompt_token_ids = request_func_input.prompt
366-
request_func_input.prompt = ""
363+
364+
# When the prompt is a list of token ids, route through prompt_token_ids
365+
# regardless of random_flag.
366+
if isinstance(request_func_input.prompt, list):
367+
request_func_input.prompt_token_ids = request_func_input.prompt
368+
request_func_input.prompt = ""
367369

368370
# 支持传入prompt_token_ids
369371
if request_func_input.prompt_token_ids:

benchmarks/benchmark_dataset.py

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -300,19 +300,30 @@ def sample(
300300
if len(samples) >= num_requests:
301301
break
302302
json_data = entry
303-
prompt = entry["messages"][-1].get("content", "")
304-
history_QA = entry.get("messages", [])
305303
response_format = entry.get("response_format")
306304
new_output_len = int(entry.get("max_tokens", output_len if output_len else 12288))
307305

308-
if enable_multimodal_chat:
309-
prompt = self.apply_multimodal_chat_transformation(prompt, None)
306+
# If the sample already carries pre-tokenized input_ids, send them
307+
# directly via prompt_token_ids and skip the server-side
308+
# chat_template + tokenizer step.
309+
input_ids = entry.get("input_ids")
310+
if input_ids is not None:
311+
prompt = [int(x) for x in input_ids]
312+
history_QA = []
313+
prompt_len = len(prompt)
314+
else:
315+
prompt = entry["messages"][-1].get("content", "")
316+
history_QA = entry.get("messages", [])
317+
prompt_len = 0
318+
if enable_multimodal_chat:
319+
prompt = self.apply_multimodal_chat_transformation(prompt, None)
320+
310321
samples.append(
311322
SampleRequest(
312323
no=cnt,
313324
json_data=json_data,
314325
prompt=prompt,
315-
prompt_len=0,
326+
prompt_len=prompt_len,
316327
history_QA=history_QA,
317328
expected_output_len=new_output_len,
318329
response_format=response_format,

0 commit comments

Comments
 (0)