Skip to content

Commit d5a4c01

Browse files
committed
Add unit and e2e test
Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
1 parent adcf656 commit d5a4c01

6 files changed

Lines changed: 884 additions & 59 deletions

File tree

examples/evaluate/perf-benchmark/scripts/fetch_vllm_metrics.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,16 @@
1414

1515
import argparse
1616
import json
17+
import logging
1718
import re
1819
import sys
1920
from dataclasses import dataclass
2021
from pathlib import Path
21-
from typing import Any
2222

2323
import requests
2424

25+
logger = logging.getLogger(__name__)
26+
2527

2628
@dataclass
2729
class Metric:
@@ -52,26 +54,26 @@ def fetch_metrics(url: str, timeout: int = 10) -> str:
5254
response.raise_for_status()
5355
return response.text
5456
except requests.exceptions.RequestException as e:
55-
print(f"[ERROR] Failed to fetch metrics from {metrics_url}: {e}", file=sys.stderr)
57+
logger.error("Failed to fetch metrics from %s: %s", metrics_url, e)
5658
sys.exit(1)
5759

5860

59-
def parse_prometheus_metrics(raw_text: str) -> list[Metric]:
61+
def parse_prometheus_metrics(raw_text: str) -> list[Metric]: # noqa: C901
6062
"""Parse Prometheus-formatted metrics into Metric objects."""
6163
metrics: list[Metric] = []
6264
lines = raw_text.strip().split("\n")
6365

6466
# Track vector metrics (those with position labels)
6567
vector_data: dict[str, list[tuple[int, float]]] = {}
6668

67-
for line in lines:
68-
line = line.strip()
69+
for raw_line in lines:
70+
line = raw_line.strip()
6971
if not line or line.startswith("#"):
7072
continue
7173

7274
# Match metric lines like: metric_name{labels} value
7375
# or simple: metric_name value
74-
match = re.match(r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([0-9.eE+-]+)$', line)
76+
match = re.match(r"^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([0-9.eE+-]+)$", line)
7577
if match:
7678
name, value = match.groups()
7779
if "vllm:spec_decode" in name:
@@ -80,7 +82,7 @@ def parse_prometheus_metrics(raw_text: str) -> list[Metric]:
8082

8183
# Match labeled metrics like: metric_name{label="value"} value
8284
match = re.match(
83-
r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]+)\}\s+([0-9.eE+-]+)$', line
85+
r"^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]+)\}\s+([0-9.eE+-]+)$", line
8486
)
8587
if match:
8688
name, labels_str, value = match.groups()
@@ -105,7 +107,9 @@ def parse_prometheus_metrics(raw_text: str) -> list[Metric]:
105107
return metrics
106108

107109

108-
def extract_metrics(raw_metrics: list[Metric], total_num_output_tokens: int) -> dict:
110+
def extract_metrics( # noqa: C901
111+
raw_metrics: list[Metric], total_num_output_tokens: int
112+
) -> dict:
109113
"""Extract speculative decoding metrics and calculate acceptance rates."""
110114
metrics_dict: dict[str, int | float] = {}
111115
num_drafts = 0
@@ -115,16 +119,12 @@ def extract_metrics(raw_metrics: list[Metric], total_num_output_tokens: int) ->
115119

116120
for metric in raw_metrics:
117121
if metric.name == "vllm:spec_decode_num_drafts":
118-
assert isinstance(metric, Counter)
119122
num_drafts += metric.value
120123
elif metric.name == "vllm:spec_decode_num_draft_tokens":
121-
assert isinstance(metric, Counter)
122124
num_draft_tokens += metric.value
123125
elif metric.name == "vllm:spec_decode_num_accepted_tokens":
124-
assert isinstance(metric, Counter)
125126
num_accepted_tokens += metric.value
126127
elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
127-
assert isinstance(metric, Vector)
128128
if len(acceptance_counts) < len(metric.values):
129129
acceptance_counts = acceptance_counts + [0.0] * (
130130
len(metric.values) - len(acceptance_counts)
@@ -154,7 +154,9 @@ def format_output(metrics: dict[str, int | float]) -> str:
154154
lines.append(f"Number of drafts: {metrics.get('num_drafts', 0)}")
155155
lines.append(f"Draft tokens proposed: {metrics.get('num_draft_tokens', 0)}")
156156
lines.append(f"Draft tokens accepted: {metrics.get('num_accepted_tokens', 0)}")
157-
lines.append(f"Average acceptance length: {metrics.get('acceptance_length', 0):.2f}")
157+
lines.append(
158+
f"Average acceptance length: {metrics.get('acceptance_length', 0):.2f}"
159+
)
158160
lines.append("\nPer-position acceptance rates:")
159161

160162
pos = 0
@@ -167,6 +169,13 @@ def format_output(metrics: dict[str, int | float]) -> str:
167169

168170

169171
def main() -> None:
172+
# Configure logging for CLI usage
173+
logging.basicConfig(
174+
level=logging.INFO,
175+
format="[%(levelname)s] %(message)s",
176+
stream=sys.stderr,
177+
)
178+
170179
parser = argparse.ArgumentParser(
171180
description="Fetch vLLM metrics and calculate acceptance rates."
172181
)
@@ -197,24 +206,24 @@ def main() -> None:
197206
args = parser.parse_args()
198207

199208
# Fetch raw metrics
200-
print(f"[INFO] Fetching metrics from {args.url}/metrics")
209+
logger.info("Fetching metrics from %s/metrics", args.url)
201210
raw_text = fetch_metrics(args.url, timeout=args.timeout)
202211

203212
# Parse metrics
204213
parsed_metrics = parse_prometheus_metrics(raw_text)
205-
print(f"[INFO] Parsed {len(parsed_metrics)} speculative decoding metrics")
214+
logger.info("Parsed %d speculative decoding metrics", len(parsed_metrics))
206215

207216
# Extract and calculate
208217
result = extract_metrics(parsed_metrics, args.total_tokens)
209218

210219
# Output
211-
print(format_output(result))
220+
logger.info("\n%s", format_output(result))
212221

213222
if args.output:
214223
args.output.parent.mkdir(parents=True, exist_ok=True)
215224
with args.output.open("w") as f:
216225
json.dump(result, f, indent=2)
217-
print(f"\n[INFO] Metrics saved to: {args.output}")
226+
logger.info("Metrics saved to: %s", args.output)
218227

219228

220229
if __name__ == "__main__":

0 commit comments

Comments
 (0)