@@ -18,7 +18,9 @@ class RunMetrics:
1818 """Metrics from a single run."""
1919
2020 generated_tokens : int
21+ prompt_tokens : int
2122 tokens_per_sec : float
23+ prefill_tokens_per_sec : float
2224 model_load_time_ms : float
2325 total_inference_time_ms : float
2426 encoder_time_ms : float
@@ -28,7 +30,8 @@ class RunMetrics:
2830 def __repr__ (self ):
2931 return (
3032 f"Tokens: { self .generated_tokens } , "
31- f"Throughput: { self .tokens_per_sec :.2f} t/s, "
33+ f"Prefill: { self .prefill_tokens_per_sec :.2f} t/s ({ self .prompt_tokens } tokens), "
34+ f"Decode: { self .tokens_per_sec :.2f} t/s, "
3235 f"Model load: { self .model_load_time_ms :.0f} ms, "
3336 f"Total inference: { self .total_inference_time_ms :.0f} ms, "
3437 f"Encoder: { self .encoder_time_ms :.0f} ms, "
@@ -49,6 +52,7 @@ def parse_pytorch_observer_log(log_line: str) -> Optional[RunMetrics]:
4952
5053 # Extract values
5154 generated_tokens = data .get ("generated_tokens" , 0 )
55+ prompt_tokens = data .get ("prompt_tokens" , 0 )
5256 inference_start_ms = data .get ("inference_start_ms" , 0 )
5357 inference_end_ms = data .get ("inference_end_ms" , 0 )
5458 prompt_eval_end_ms = data .get ("prompt_eval_end_ms" , 0 )
@@ -72,12 +76,20 @@ def parse_pytorch_observer_log(log_line: str) -> Optional[RunMetrics]:
7276 if generation_time_ms > 0
7377 else 0
7478 )
79+
80+ # Calculate prefill throughput
81+ prefill_tokens_per_sec = (
82+ (prompt_tokens / encoder_time_ms * 1000 ) if encoder_time_ms > 0 else 0
83+ )
84+
7585 model_load_time_ms = model_load_end_ms - model_load_start_ms
7686 first_token_latency_ms = first_token_ms - prompt_eval_end_ms
7787
7888 return RunMetrics (
7989 generated_tokens = generated_tokens ,
90+ prompt_tokens = prompt_tokens ,
8091 tokens_per_sec = tokens_per_sec ,
92+ prefill_tokens_per_sec = prefill_tokens_per_sec ,
8193 model_load_time_ms = model_load_time_ms ,
8294 total_inference_time_ms = total_inference_time_ms ,
8395 encoder_time_ms = encoder_time_ms ,
@@ -505,6 +517,7 @@ class BenchmarkResults:
505517
506518 # Metrics
507519 throughput : MetricStats
520+ prefill_throughput : MetricStats
508521 model_load_time : MetricStats
509522 total_inference_time : MetricStats
510523 encoder_time : MetricStats
@@ -529,6 +542,10 @@ def to_dict(self) -> dict:
529542 "throughput_min" : self .throughput .min_val ,
530543 "throughput_max" : self .throughput .max_val ,
531544 "throughput_stdev" : self .throughput .stdev ,
545+ "prefill_throughput_mean" : self .prefill_throughput .mean ,
546+ "prefill_throughput_min" : self .prefill_throughput .min_val ,
547+ "prefill_throughput_max" : self .prefill_throughput .max_val ,
548+ "prefill_throughput_stdev" : self .prefill_throughput .stdev ,
532549 "model_load_time_mean" : self .model_load_time .mean ,
533550 "model_load_time_min" : self .model_load_time .min_val ,
534551 "model_load_time_max" : self .model_load_time .max_val ,
@@ -601,6 +618,13 @@ def to_v3_format(
601618 runner_type ,
602619 base_extra_info ,
603620 ),
621+ self .prefill_throughput .create_v3_record (
622+ model_name_with_quant ,
623+ backend ,
624+ runner_name ,
625+ runner_type ,
626+ base_extra_info ,
627+ ),
604628 self .model_load_time .create_v3_record (
605629 model_name_with_quant ,
606630 backend ,
@@ -696,6 +720,11 @@ def create_metric_stats(
696720 "t/s" ,
697721 {"trimmed_runs" : len (trimmed_throughput )},
698722 ),
723+ prefill_throughput = create_metric_stats (
724+ "prefill_encoder_throughput(tokens/sec)" ,
725+ [r .prefill_tokens_per_sec for r in results ],
726+ "t/s" ,
727+ ),
699728 model_load_time = create_metric_stats (
700729 "model_load_time(ms)" ,
701730 [r .model_load_time_ms for r in results ],
@@ -740,6 +769,7 @@ def print_summary(summary: BenchmarkResults) -> None:
740769
741770 # Print all metrics using their print_stats method
742771 summary .throughput .print_stats ()
772+ summary .prefill_throughput .print_stats ()
743773 summary .model_load_time .print_stats ()
744774 summary .total_inference_time .print_stats ()
745775 summary .encoder_time .print_stats ()