22
33This script focuses on eval-style JSONL dumps where each row contains OpenAI
44chat ``messages``, or a string/list ``prompt`` (e.g. dapo-math-17k). List-type
5- ``prompt`` values are treated as message lists. It records streaming latency traces,
5+ ``prompt`` values are treated as message lists. Optional per-row ``tools`` and
6+ ``tool_choice`` fields are forwarded to ``/v1/chat/completions``. It records streaming latency traces,
67aggregates TTFT/ITL/TPOT metrics, and writes table plus report artifacts for concurrency/RPS sweeps.
78
89Generation options include ``--output-tokens`` (``max_completion_tokens``),
@@ -42,6 +43,8 @@ class BenchmarkRequest:
4243 messages : list [dict [str , Any ]] = field (default_factory = list )
4344 input_ids : list [int ] | None = None
4445 image_data : Any = None
46+ tools : list [dict [str , Any ]] | None = None
47+ tool_choice : Any | None = None
4548
4649
4750@dataclass
@@ -73,6 +76,7 @@ class RequestTrace:
7376 chunk_times : list [float ] = field (default_factory = list )
7477 prompt_tokens : int = 0
7578 completion_tokens : int = 0
79+ cached_tokens : int = 0
7680 usage_available : bool = False
7781 generated_text : str = ''
7882 reasoning_text : str = ''
@@ -205,6 +209,21 @@ def _extract_messages(row: dict[str, Any]) -> list[dict[str, Any]]:
205209 raise ValueError ('row must contain messages or prompt' )
206210
207211
212+ def _extract_tools (row : dict [str , Any ]) -> list [dict [str , Any ]] | None :
213+ tools = row .get ('tools' )
214+ if not tools :
215+ return None
216+ if not isinstance (tools , list ):
217+ raise ValueError ('tools must be a list when present' )
218+ return tools
219+
220+
221+ def _extract_tool_choice (row : dict [str , Any ]) -> Any | None :
222+ if 'tool_choice' not in row :
223+ return None
224+ return row ['tool_choice' ]
225+
226+
208227def _normalize_row (
209228 row : dict [str , Any ],
210229 dataset : str ,
@@ -213,23 +232,35 @@ def _normalize_row(
213232) -> BenchmarkRequest :
214233 request_id = str (row .get ('id' , f'{ dataset } -{ row_index } ' ))
215234 messages = _extract_messages (row )
235+ tools = _extract_tools (row )
236+ tool_choice = _extract_tool_choice (row )
216237
217238 if tokenizer is not None :
218- prompt_str = tokenizer .apply_chat_template (
219- messages ,
220- tokenize = False ,
221- add_generation_prompt = True ,
222- )
239+ template_kwargs : dict [str , Any ] = {
240+ 'tokenize' : False ,
241+ 'add_generation_prompt' : True ,
242+ }
243+ if tools is not None :
244+ template_kwargs ['tools' ] = tools
245+ prompt_str = tokenizer .apply_chat_template (messages , ** template_kwargs )
223246 return BenchmarkRequest (
224247 dataset = dataset ,
225248 id = request_id ,
226249 input_ids = tokenizer .encode (prompt_str , add_special_tokens = False ),
227250 image_data = row .get ('image_data' ),
251+ tools = tools ,
252+ tool_choice = tool_choice ,
228253 )
229254
230255 if not messages :
231256 raise ValueError (f'row { row_index } in { dataset } has invalid messages' )
232- return BenchmarkRequest (dataset = dataset , id = request_id , messages = messages )
257+ return BenchmarkRequest (
258+ dataset = dataset ,
259+ id = request_id ,
260+ messages = messages ,
261+ tools = tools ,
262+ tool_choice = tool_choice ,
263+ )
233264
234265
235266def _read_raw_rows (
@@ -326,6 +357,13 @@ def parse_sse_line(line: bytes | str) -> SSEEvent:
326357 )
327358
328359
360+ def _cached_tokens_from_usage (usage : dict [str , Any ] | None ) -> int :
361+ if not usage :
362+ return 0
363+ details = usage .get ('prompt_tokens_details' ) or {}
364+ return int (details .get ('cached_tokens' , 0 ) or 0 )
365+
366+
329367def build_payload (
330368 request : BenchmarkRequest ,
331369 model : str ,
@@ -372,6 +410,10 @@ def build_payload(
372410 payload ['logprobs' ] = True
373411 if top_logprobs is not None :
374412 payload ['top_logprobs' ] = top_logprobs
413+ if request .tools :
414+ payload ['tools' ] = request .tools
415+ if request .tool_choice is not None :
416+ payload ['tool_choice' ] = request .tool_choice
375417 if extra_body :
376418 payload .update (extra_body )
377419 return payload
@@ -482,6 +524,7 @@ async def request_chat_completion(
482524 trace .completion_tokens = int (
483525 event .usage .get ('completion_tokens' , trace .completion_tokens ) or 0
484526 )
527+ trace .cached_tokens = _cached_tokens_from_usage (event .usage )
485528 if event .routed_experts and shared_store is not None :
486529 try :
487530 await fetch_routed_experts (shared_store , event .routed_experts )
@@ -632,6 +675,7 @@ def aggregate_traces(traces: Sequence[RequestTrace]) -> list[dict[str, Any]]:
632675 duration = max (end - start , 0.0 )
633676 total_input = sum (trace .prompt_tokens for trace in completed )
634677 total_output = sum (trace .completion_tokens for trace in completed )
678+ total_cached = sum (trace .cached_tokens for trace in completed )
635679 itls = [itl for trace in completed for itl in trace .itls_s ]
636680
637681 summary : dict [str , Any ] = {
@@ -646,6 +690,8 @@ def aggregate_traces(traces: Sequence[RequestTrace]) -> list[dict[str, Any]]:
646690 'duration_s' : duration ,
647691 'total_input_tokens' : total_input ,
648692 'total_output_tokens' : total_output ,
693+ 'total_cached_tokens' : total_cached ,
694+ 'cache_hit_rate' : total_cached / total_input if total_input > 0 else 0.0 ,
649695 'request_throughput_req_s' : len (completed ) / duration if duration > 0 else 0.0 ,
650696 'input_throughput_tok_s' : total_input / duration if duration > 0 else 0.0 ,
651697 'output_throughput_tok_s' : total_output / duration if duration > 0 else 0.0 ,
@@ -691,6 +737,7 @@ def _write_requests_csv(path: Path, rows: Sequence[dict[str, Any]]) -> None:
691737 'e2e_latency_s' ,
692738 'prompt_tokens' ,
693739 'completion_tokens' ,
740+ 'cached_tokens' ,
694741 'usage_available' ,
695742 'finish_reason' ,
696743 'error' ,
0 commit comments