@@ -150,7 +150,7 @@ async def stream_one_turn(
150150 line = raw .strip ()
151151 if not line or not line .startswith (b"data:" ):
152152 continue
153- data_str = line [len (b"data:" ):].strip ()
153+ data_str = line [len (b"data:" ) :].strip ()
154154 if data_str == b"[DONE]" :
155155 break
156156 try :
@@ -219,17 +219,13 @@ async def run_session(
219219 """Run a single multi-turn dialogue session. Returns a list of per-turn
220220 stat dicts (same schema as stream_one_turn output)."""
221221 rng = random .Random (base_seed + session_id )
222- prompt , prompt_len = gen_session_initial_prompt (
223- tokenizer , start_input_len , base_seed + session_id
224- )
222+ prompt , prompt_len = gen_session_initial_prompt (tokenizer , start_input_len , base_seed + session_id )
225223
226224 per_turn : List [Dict ] = []
227225 turn_idx = 0
228226 while turn_idx < max_turns and prompt_len < max_input_len :
229227 turn_output_len = rng .randint (min_output_len , output_len )
230- result = await stream_one_turn (
231- session , url , model_name , prompt , turn_output_len
232- )
228+ result = await stream_one_turn (session , url , model_name , prompt , turn_output_len )
233229 if result is None :
234230 break
235231 per_turn .append (result )
@@ -382,28 +378,36 @@ def summarize(
382378
383379def print_summary (summary : Dict ) -> None :
384380 print ("=" * 80 )
385- print (f"Concurrency = { summary ['concurrency' ]} sessions = { summary ['num_sessions' ]} "
386- f"total_turns = { summary ['total_turns' ]} wall_time = { summary ['wall_time_s' ]} s" )
381+ print (
382+ f"Concurrency = { summary ['concurrency' ]} sessions = { summary ['num_sessions' ]} "
383+ f"total_turns = { summary ['total_turns' ]} wall_time = { summary ['wall_time_s' ]} s"
384+ )
387385 if "error" in summary :
388386 print (f" ERROR: { summary ['error' ]} " )
389387 return
390388 print (f" QPS : { summary ['QPS' ]} " )
391389 print (f" TPM (total) : { summary ['TPM_total' ]} " )
392390 print (f" TPM (prompt) : { summary ['TPM_prompt' ]} " )
393391 print (f" TPM (completion) : { summary ['TPM_completion' ]} " )
394- print (f" Cache hit ratio : { summary ['cache_hit_ratio' ] * 100 :.2f} % "
395- f"({ summary ['total_cached_prompt_tokens' ]} / { summary ['total_prompt_tokens' ]} )" )
392+ print (
393+ f" Cache hit ratio : { summary ['cache_hit_ratio' ] * 100 :.2f} % "
394+ f"({ summary ['total_cached_prompt_tokens' ]} / { summary ['total_prompt_tokens' ]} )"
395+ )
396396 print (f" Avg prompt tokens : { summary ['avg_prompt_tokens_per_turn' ]} " )
397397 print (f" Avg output tokens : { summary ['avg_completion_tokens_per_turn' ]} " )
398398 ttft = summary ["TTFT_ms" ]
399399 tpot = summary ["TPOT_ms" ]
400- print (f" TTFT ms mean={ ttft ['mean' ]} P50={ ttft .get ('P50' )} P90={ ttft .get ('P90' )} "
401- f"P95={ ttft .get ('P95' )} P99={ ttft .get ('P99' )} " )
400+ print (
401+ f" TTFT ms mean={ ttft ['mean' ]} P50={ ttft .get ('P50' )} P90={ ttft .get ('P90' )} "
402+ f"P95={ ttft .get ('P95' )} P99={ ttft .get ('P99' )} "
403+ )
402404 if tpot .get ("mean" ) is None :
403405 print (f" TPOT ms (n/a: { tpot .get ('note' )} )" )
404406 else :
405- print (f" TPOT ms mean={ tpot ['mean' ]} P50={ tpot .get ('P50' )} P90={ tpot .get ('P90' )} "
406- f"P95={ tpot .get ('P95' )} P99={ tpot .get ('P99' )} " )
407+ print (
408+ f" TPOT ms mean={ tpot ['mean' ]} P50={ tpot .get ('P50' )} P90={ tpot .get ('P90' )} "
409+ f"P95={ tpot .get ('P95' )} P99={ tpot .get ('P99' )} "
410+ )
407411
408412
409413def main () -> None :
@@ -413,7 +417,7 @@ def main() -> None:
413417 type = str ,
414418 default = "http://127.0.0.1:8088/v1/completions" ,
415419 help = "Streaming OpenAI completion endpoint. The benchmark relies on "
416- "the final SSE `usage` chunk to obtain cached_tokens." ,
420+ "the final SSE `usage` chunk to obtain cached_tokens." ,
417421 )
418422 parser .add_argument ("--tokenizer_path" , type = str , required = True )
419423 parser .add_argument (
@@ -428,30 +432,37 @@ def main() -> None:
428432 default = "1,4,8,16,32,64,128,256" ,
429433 help = "Comma-separated list of concurrency levels to sweep." ,
430434 )
431- parser .add_argument ("--start_input_len" , type = int , default = 32768 ,
432- help = "Initial prompt length in tokens per session." )
433- parser .add_argument ("--max_input_len" , type = int , default = 163840 ,
434- help = "Stop a session when its prompt exceeds this length." )
435- parser .add_argument ("--turn_input_increment" , type = int , default = 2048 ,
436- help = "Maximum new 'user' tokens sampled after each turn, on top "
437- "of the model's generated text." )
438- parser .add_argument ("--min_turn_input_increment" , type = int , default = 512 ,
439- help = "Minimum new 'user' tokens sampled after each turn." )
440- parser .add_argument ("--output_len" , type = int , default = 512 ,
441- help = "Maximum max_new_tokens sampled per turn." )
442- parser .add_argument ("--min_output_len" , type = int , default = 128 ,
443- help = "Minimum max_new_tokens sampled per turn." )
444- parser .add_argument ("--max_turns" , type = int , default = 64 ,
445- help = "Hard cap on turns per session. The session also stops once "
446- "prompt length reaches --max_input_len." )
435+ parser .add_argument (
436+ "--start_input_len" , type = int , default = 32768 , help = "Initial prompt length in tokens per session."
437+ )
438+ parser .add_argument (
439+ "--max_input_len" , type = int , default = 163840 , help = "Stop a session when its prompt exceeds this length."
440+ )
441+ parser .add_argument (
442+ "--turn_input_increment" ,
443+ type = int ,
444+ default = 2048 ,
445+ help = "Maximum new 'user' tokens sampled after each turn, on top " "of the model's generated text." ,
446+ )
447+ parser .add_argument (
448+ "--min_turn_input_increment" , type = int , default = 512 , help = "Minimum new 'user' tokens sampled after each turn."
449+ )
450+ parser .add_argument ("--output_len" , type = int , default = 512 , help = "Maximum max_new_tokens sampled per turn." )
451+ parser .add_argument ("--min_output_len" , type = int , default = 128 , help = "Minimum max_new_tokens sampled per turn." )
452+ parser .add_argument (
453+ "--max_turns" ,
454+ type = int ,
455+ default = 64 ,
456+ help = "Hard cap on turns per session. The session also stops once " "prompt length reaches --max_input_len." ,
457+ )
447458 parser .add_argument ("--seed" , type = int , default = 0 )
448459 parser .add_argument ("--request_timeout_s" , type = int , default = 3600 )
449460 parser .add_argument (
450461 "--dump_file" ,
451462 type = str ,
452463 default = "" ,
453464 help = "If set, append the per-concurrency summary dict to this JSON file. "
454- "If the file already exists and is non-empty, it is read and printed." ,
465+ "If the file already exists and is non-empty, it is read and printed." ,
455466 )
456467
457468 args = parser .parse_args ()
0 commit comments