1414
1515import argparse
1616import json
17+ import logging
1718import re
1819import sys
1920from dataclasses import dataclass
2021from pathlib import Path
21- from typing import Any
2222
2323import requests
2424
25+ logger = logging .getLogger (__name__ )
26+
2527
2628@dataclass
2729class Metric :
@@ -52,26 +54,26 @@ def fetch_metrics(url: str, timeout: int = 10) -> str:
5254 response .raise_for_status ()
5355 return response .text
5456 except requests .exceptions .RequestException as e :
55- print ( f"[ERROR] Failed to fetch metrics from { metrics_url } : { e } " , file = sys . stderr )
57+ logger . error ( " Failed to fetch metrics from %s: %s " , metrics_url , e )
5658 sys .exit (1 )
5759
5860
59- def parse_prometheus_metrics (raw_text : str ) -> list [Metric ]:
61+ def parse_prometheus_metrics (raw_text : str ) -> list [Metric ]: # noqa: C901
6062 """Parse Prometheus-formatted metrics into Metric objects."""
6163 metrics : list [Metric ] = []
6264 lines = raw_text .strip ().split ("\n " )
6365
6466 # Track vector metrics (those with position labels)
6567 vector_data : dict [str , list [tuple [int , float ]]] = {}
6668
67- for line in lines :
68- line = line .strip ()
69+ for raw_line in lines :
70+ line = raw_line .strip ()
6971 if not line or line .startswith ("#" ):
7072 continue
7173
7274 # Match metric lines like: metric_name{labels} value
7375 # or simple: metric_name value
74- match = re .match (r' ^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([0-9.eE+-]+)$' , line )
76+ match = re .match (r" ^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([0-9.eE+-]+)$" , line )
7577 if match :
7678 name , value = match .groups ()
7779 if "vllm:spec_decode" in name :
@@ -80,7 +82,7 @@ def parse_prometheus_metrics(raw_text: str) -> list[Metric]:
8082
8183 # Match labeled metrics like: metric_name{label="value"} value
8284 match = re .match (
83- r' ^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]+)\}\s+([0-9.eE+-]+)$' , line
85+ r" ^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]+)\}\s+([0-9.eE+-]+)$" , line
8486 )
8587 if match :
8688 name , labels_str , value = match .groups ()
@@ -105,7 +107,9 @@ def parse_prometheus_metrics(raw_text: str) -> list[Metric]:
105107 return metrics
106108
107109
108- def extract_metrics (raw_metrics : list [Metric ], total_num_output_tokens : int ) -> dict :
110+ def extract_metrics ( # noqa: C901
111+ raw_metrics : list [Metric ], total_num_output_tokens : int
112+ ) -> dict :
109113 """Extract speculative decoding metrics and calculate acceptance rates."""
110114 metrics_dict : dict [str , int | float ] = {}
111115 num_drafts = 0
@@ -115,16 +119,12 @@ def extract_metrics(raw_metrics: list[Metric], total_num_output_tokens: int) ->
115119
116120 for metric in raw_metrics :
117121 if metric .name == "vllm:spec_decode_num_drafts" :
118- assert isinstance (metric , Counter )
119122 num_drafts += metric .value
120123 elif metric .name == "vllm:spec_decode_num_draft_tokens" :
121- assert isinstance (metric , Counter )
122124 num_draft_tokens += metric .value
123125 elif metric .name == "vllm:spec_decode_num_accepted_tokens" :
124- assert isinstance (metric , Counter )
125126 num_accepted_tokens += metric .value
126127 elif metric .name == "vllm:spec_decode_num_accepted_tokens_per_pos" :
127- assert isinstance (metric , Vector )
128128 if len (acceptance_counts ) < len (metric .values ):
129129 acceptance_counts = acceptance_counts + [0.0 ] * (
130130 len (metric .values ) - len (acceptance_counts )
@@ -154,7 +154,9 @@ def format_output(metrics: dict[str, int | float]) -> str:
154154 lines .append (f"Number of drafts: { metrics .get ('num_drafts' , 0 )} " )
155155 lines .append (f"Draft tokens proposed: { metrics .get ('num_draft_tokens' , 0 )} " )
156156 lines .append (f"Draft tokens accepted: { metrics .get ('num_accepted_tokens' , 0 )} " )
157- lines .append (f"Average acceptance length: { metrics .get ('acceptance_length' , 0 ):.2f} " )
157+ lines .append (
158+ f"Average acceptance length: { metrics .get ('acceptance_length' , 0 ):.2f} "
159+ )
158160 lines .append ("\n Per-position acceptance rates:" )
159161
160162 pos = 0
@@ -167,6 +169,13 @@ def format_output(metrics: dict[str, int | float]) -> str:
167169
168170
169171def main () -> None :
172+ # Configure logging for CLI usage
173+ logging .basicConfig (
174+ level = logging .INFO ,
175+ format = "[%(levelname)s] %(message)s" ,
176+ stream = sys .stderr ,
177+ )
178+
170179 parser = argparse .ArgumentParser (
171180 description = "Fetch vLLM metrics and calculate acceptance rates."
172181 )
@@ -197,24 +206,24 @@ def main() -> None:
197206 args = parser .parse_args ()
198207
199208 # Fetch raw metrics
200- print ( f"[INFO] Fetching metrics from { args . url } /metrics" )
209+ logger . info ( " Fetching metrics from %s /metrics", args . url )
201210 raw_text = fetch_metrics (args .url , timeout = args .timeout )
202211
203212 # Parse metrics
204213 parsed_metrics = parse_prometheus_metrics (raw_text )
205- print ( f"[INFO] Parsed { len ( parsed_metrics ) } speculative decoding metrics" )
214+ logger . info ( " Parsed %d speculative decoding metrics", len ( parsed_metrics ) )
206215
207216 # Extract and calculate
208217 result = extract_metrics (parsed_metrics , args .total_tokens )
209218
210219 # Output
211- print ( format_output (result ))
220+ logger . info ( " \n %s" , format_output (result ))
212221
213222 if args .output :
214223 args .output .parent .mkdir (parents = True , exist_ok = True )
215224 with args .output .open ("w" ) as f :
216225 json .dump (result , f , indent = 2 )
217- print ( f" \n [INFO] Metrics saved to: { args .output } " )
226+ logger . info ( " Metrics saved to: %s" , args .output )
218227
219228
220229if __name__ == "__main__" :
0 commit comments