Skip to content

Commit c57e062

Browse files
committed
format
1 parent 819497c commit c57e062

1 file changed

Lines changed: 44 additions & 33 deletions

File tree

test/benchmark/service/benchmark_multiturn.py

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ async def stream_one_turn(
150150
line = raw.strip()
151151
if not line or not line.startswith(b"data:"):
152152
continue
153-
data_str = line[len(b"data:"):].strip()
153+
data_str = line[len(b"data:") :].strip()
154154
if data_str == b"[DONE]":
155155
break
156156
try:
@@ -219,17 +219,13 @@ async def run_session(
219219
"""Run a single multi-turn dialogue session. Returns a list of per-turn
220220
stat dicts (same schema as stream_one_turn output)."""
221221
rng = random.Random(base_seed + session_id)
222-
prompt, prompt_len = gen_session_initial_prompt(
223-
tokenizer, start_input_len, base_seed + session_id
224-
)
222+
prompt, prompt_len = gen_session_initial_prompt(tokenizer, start_input_len, base_seed + session_id)
225223

226224
per_turn: List[Dict] = []
227225
turn_idx = 0
228226
while turn_idx < max_turns and prompt_len < max_input_len:
229227
turn_output_len = rng.randint(min_output_len, output_len)
230-
result = await stream_one_turn(
231-
session, url, model_name, prompt, turn_output_len
232-
)
228+
result = await stream_one_turn(session, url, model_name, prompt, turn_output_len)
233229
if result is None:
234230
break
235231
per_turn.append(result)
@@ -382,28 +378,36 @@ def summarize(
382378

383379
def print_summary(summary: Dict) -> None:
384380
print("=" * 80)
385-
print(f"Concurrency = {summary['concurrency']} sessions = {summary['num_sessions']} "
386-
f"total_turns = {summary['total_turns']} wall_time = {summary['wall_time_s']}s")
381+
print(
382+
f"Concurrency = {summary['concurrency']} sessions = {summary['num_sessions']} "
383+
f"total_turns = {summary['total_turns']} wall_time = {summary['wall_time_s']}s"
384+
)
387385
if "error" in summary:
388386
print(f" ERROR: {summary['error']}")
389387
return
390388
print(f" QPS : {summary['QPS']}")
391389
print(f" TPM (total) : {summary['TPM_total']}")
392390
print(f" TPM (prompt) : {summary['TPM_prompt']}")
393391
print(f" TPM (completion) : {summary['TPM_completion']}")
394-
print(f" Cache hit ratio : {summary['cache_hit_ratio'] * 100:.2f}% "
395-
f"({summary['total_cached_prompt_tokens']} / {summary['total_prompt_tokens']})")
392+
print(
393+
f" Cache hit ratio : {summary['cache_hit_ratio'] * 100:.2f}% "
394+
f"({summary['total_cached_prompt_tokens']} / {summary['total_prompt_tokens']})"
395+
)
396396
print(f" Avg prompt tokens : {summary['avg_prompt_tokens_per_turn']}")
397397
print(f" Avg output tokens : {summary['avg_completion_tokens_per_turn']}")
398398
ttft = summary["TTFT_ms"]
399399
tpot = summary["TPOT_ms"]
400-
print(f" TTFT ms mean={ttft['mean']} P50={ttft.get('P50')} P90={ttft.get('P90')} "
401-
f"P95={ttft.get('P95')} P99={ttft.get('P99')}")
400+
print(
401+
f" TTFT ms mean={ttft['mean']} P50={ttft.get('P50')} P90={ttft.get('P90')} "
402+
f"P95={ttft.get('P95')} P99={ttft.get('P99')}"
403+
)
402404
if tpot.get("mean") is None:
403405
print(f" TPOT ms (n/a: {tpot.get('note')})")
404406
else:
405-
print(f" TPOT ms mean={tpot['mean']} P50={tpot.get('P50')} P90={tpot.get('P90')} "
406-
f"P95={tpot.get('P95')} P99={tpot.get('P99')}")
407+
print(
408+
f" TPOT ms mean={tpot['mean']} P50={tpot.get('P50')} P90={tpot.get('P90')} "
409+
f"P95={tpot.get('P95')} P99={tpot.get('P99')}"
410+
)
407411

408412

409413
def main() -> None:
@@ -413,7 +417,7 @@ def main() -> None:
413417
type=str,
414418
default="http://127.0.0.1:8088/v1/completions",
415419
help="Streaming OpenAI completion endpoint. The benchmark relies on "
416-
"the final SSE `usage` chunk to obtain cached_tokens.",
420+
"the final SSE `usage` chunk to obtain cached_tokens.",
417421
)
418422
parser.add_argument("--tokenizer_path", type=str, required=True)
419423
parser.add_argument(
@@ -428,30 +432,37 @@ def main() -> None:
428432
default="1,4,8,16,32,64,128,256",
429433
help="Comma-separated list of concurrency levels to sweep.",
430434
)
431-
parser.add_argument("--start_input_len", type=int, default=32768,
432-
help="Initial prompt length in tokens per session.")
433-
parser.add_argument("--max_input_len", type=int, default=163840,
434-
help="Stop a session when its prompt exceeds this length.")
435-
parser.add_argument("--turn_input_increment", type=int, default=2048,
436-
help="Maximum new 'user' tokens sampled after each turn, on top "
437-
"of the model's generated text.")
438-
parser.add_argument("--min_turn_input_increment", type=int, default=512,
439-
help="Minimum new 'user' tokens sampled after each turn.")
440-
parser.add_argument("--output_len", type=int, default=512,
441-
help="Maximum max_new_tokens sampled per turn.")
442-
parser.add_argument("--min_output_len", type=int, default=128,
443-
help="Minimum max_new_tokens sampled per turn.")
444-
parser.add_argument("--max_turns", type=int, default=64,
445-
help="Hard cap on turns per session. The session also stops once "
446-
"prompt length reaches --max_input_len.")
435+
parser.add_argument(
436+
"--start_input_len", type=int, default=32768, help="Initial prompt length in tokens per session."
437+
)
438+
parser.add_argument(
439+
"--max_input_len", type=int, default=163840, help="Stop a session when its prompt exceeds this length."
440+
)
441+
parser.add_argument(
442+
"--turn_input_increment",
443+
type=int,
444+
default=2048,
445+
help="Maximum new 'user' tokens sampled after each turn, on top " "of the model's generated text.",
446+
)
447+
parser.add_argument(
448+
"--min_turn_input_increment", type=int, default=512, help="Minimum new 'user' tokens sampled after each turn."
449+
)
450+
parser.add_argument("--output_len", type=int, default=512, help="Maximum max_new_tokens sampled per turn.")
451+
parser.add_argument("--min_output_len", type=int, default=128, help="Minimum max_new_tokens sampled per turn.")
452+
parser.add_argument(
453+
"--max_turns",
454+
type=int,
455+
default=64,
456+
help="Hard cap on turns per session. The session also stops once " "prompt length reaches --max_input_len.",
457+
)
447458
parser.add_argument("--seed", type=int, default=0)
448459
parser.add_argument("--request_timeout_s", type=int, default=3600)
449460
parser.add_argument(
450461
"--dump_file",
451462
type=str,
452463
default="",
453464
help="If set, append the per-concurrency summary dict to this JSON file. "
454-
"If the file already exists and is non-empty, it is read and printed.",
465+
"If the file already exists and is non-empty, it is read and printed.",
455466
)
456467

457468
args = parser.parse_args()

0 commit comments

Comments
 (0)