Skip to content

Commit 447162a

Browse files
committed
wip
1 parent d10252c commit 447162a

2 files changed

Lines changed: 182 additions & 2 deletions

File tree

.github/scripts/aggregate_benchmarks.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
#!/usr/bin/env python3
2-
"""Aggregate benchmark JSON files by taking the median across runner attempts."""
2+
"""Aggregate benchmark JSON files by taking the median across runner attempts.
3+
4+
The workflow runs the same benchmark suite on multiple independent runners.
5+
This script reads every JSON file produced by those attempts, normalizes the
6+
contained benchmark values, and writes a compact mapping JSON where each value is
7+
the median across attempts.
8+
"""
39

410
from __future__ import annotations
511

@@ -12,6 +18,16 @@
1218

1319

1420
def collect_benchmarks(paths: list[Path]) -> dict[str, list[Benchmark]]:
21+
"""Collect benchmarks from multiple JSON files.
22+
23+
Args:
24+
paths (list[Path]): Paths to hyperfine, pytest-benchmark, or compact
25+
mapping JSON files.
26+
27+
Returns:
28+
dict[str, list[Benchmark]]: Benchmarks grouped by benchmark name.
29+
"""
30+
1531
collected: dict[str, list[Benchmark]] = {}
1632
for path in paths:
1733
for name, benchmark in extract_benchmarks(path).items():
@@ -20,6 +36,18 @@ def collect_benchmarks(paths: list[Path]) -> dict[str, list[Benchmark]]:
2036

2137

2238
def aggregate(collected: dict[str, list[Benchmark]]) -> dict[str, dict[str, object]]:
39+
"""Aggregate grouped benchmarks using the median value.
40+
41+
Args:
42+
collected (dict[str, list[Benchmark]]): Benchmarks grouped by benchmark
43+
name.
44+
45+
Returns:
46+
dict[str, dict[str, object]]: Compact mapping JSON data. Each benchmark
47+
contains ``value``, ``unit``, ``metric``, ``attempts``, and
48+
``attempt_values``.
49+
"""
50+
2351
aggregated: dict[str, dict[str, object]] = {}
2452
for name, benchmarks in sorted(collected.items()):
2553
values = [benchmark.value for benchmark in benchmarks]
@@ -36,6 +64,19 @@ def aggregate(collected: dict[str, list[Benchmark]]) -> dict[str, dict[str, obje
3664

3765

3866
def main_from_paths(input_dir: Path, output: Path) -> int:
67+
"""Aggregate all JSON files in a directory and write the result.
68+
69+
Args:
70+
input_dir (Path): Directory containing benchmark JSON files.
71+
output (Path): Path where the aggregate JSON should be written.
72+
73+
Returns:
74+
int: Always ``0`` on success.
75+
76+
Raises:
77+
ValueError: If no JSON files are found in ``input_dir``.
78+
"""
79+
3980
paths = sorted(input_dir.rglob("*.json"))
4081
if not paths:
4182
raise ValueError(f"No benchmark JSON files found in {input_dir}")
@@ -49,6 +90,12 @@ def main_from_paths(input_dir: Path, output: Path) -> int:
4990

5091

5192
def main() -> int:
93+
"""Run the benchmark aggregation command line interface.
94+
95+
Returns:
96+
int: Always ``0`` on success.
97+
"""
98+
5299
parser = argparse.ArgumentParser()
53100
parser.add_argument("--input-dir", required=True, type=Path)
54101
parser.add_argument("--output", required=True, type=Path)

.github/scripts/compare_benchmarks.py

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
#!/usr/bin/env python3
2-
"""Compare benchmark JSON files and write a GitHub Actions summary."""
2+
"""Compare benchmark JSON files and write a GitHub Actions summary.
3+
4+
The script supports JSON emitted by hyperfine, JSON emitted by pytest-benchmark,
5+
and a compact mapping format generated by ``aggregate_benchmarks.py``. Timing
6+
formats prefer median values and fall back to mean values when median values are
7+
not present.
8+
"""
39

410
from __future__ import annotations
511

@@ -13,6 +19,15 @@
1319

1420
@dataclass(frozen=True)
1521
class Benchmark:
22+
"""Normalized benchmark result.
23+
24+
Attributes:
25+
name (str): Stable benchmark name used to match baseline and current results.
26+
value (float): Numeric benchmark value used for comparison.
27+
unit (str): Display unit for the value, for example ``"s"``.
28+
metric (str): Source metric name, for example ``"median"`` or ``"mean"``.
29+
"""
30+
1631
name: str
1732
value: float
1833
unit: str
@@ -21,6 +36,18 @@ class Benchmark:
2136

2237
@dataclass(frozen=True)
2338
class Comparison:
39+
"""Comparison between one baseline benchmark and one current benchmark.
40+
41+
Attributes:
42+
name (str): Benchmark name.
43+
baseline (float): Baseline benchmark value.
44+
current (float): Current benchmark value.
45+
delta_percent (float): Percent change from baseline to current.
46+
unit (str): Display unit for both values.
47+
metric (str): Current result metric used for comparison.
48+
regressed (bool): Whether the change exceeds the configured threshold.
49+
"""
50+
2451
name: str
2552
baseline: float
2653
current: float
@@ -31,11 +58,29 @@ class Comparison:
3158

3259

3360
def _read_json(path: Path) -> Any:
61+
"""Read JSON data from a file.
62+
63+
Args:
64+
path (Path): Path to the JSON file.
65+
66+
Returns:
67+
Any: Parsed JSON value.
68+
"""
69+
3470
with path.open("r", encoding="utf-8") as stream:
3571
return json.load(stream)
3672

3773

3874
def _as_float(value: Any) -> float | None:
75+
"""Convert a value to a finite float.
76+
77+
Args:
78+
value (Any): Value to convert.
79+
80+
Returns:
81+
float | None: Converted finite float, or ``None`` if conversion fails.
82+
"""
83+
3984
try:
4085
result = float(value)
4186
except (TypeError, ValueError):
@@ -46,6 +91,15 @@ def _as_float(value: Any) -> float | None:
4691

4792

4893
def _extract_hyperfine(data: dict[str, Any]) -> dict[str, Benchmark]:
94+
"""Extract normalized benchmarks from hyperfine JSON.
95+
96+
Args:
97+
data (dict[str, Any]): Parsed hyperfine JSON object.
98+
99+
Returns:
100+
dict[str, Benchmark]: Benchmarks keyed by command name.
101+
"""
102+
49103
benchmarks: dict[str, Benchmark] = {}
50104
for result in data.get("results", []):
51105
if not isinstance(result, dict):
@@ -62,6 +116,15 @@ def _extract_hyperfine(data: dict[str, Any]) -> dict[str, Benchmark]:
62116

63117

64118
def _extract_pytest_benchmark(data: dict[str, Any]) -> dict[str, Benchmark]:
119+
"""Extract normalized benchmarks from pytest-benchmark JSON.
120+
121+
Args:
122+
data (dict[str, Any]): Parsed pytest-benchmark JSON object.
123+
124+
Returns:
125+
dict[str, Benchmark]: Benchmarks keyed by full benchmark name.
126+
"""
127+
65128
benchmarks: dict[str, Benchmark] = {}
66129
for benchmark in data.get("benchmarks", []):
67130
if not isinstance(benchmark, dict):
@@ -82,6 +145,16 @@ def _extract_pytest_benchmark(data: dict[str, Any]) -> dict[str, Benchmark]:
82145

83146

84147
def _extract_simple_mapping(data: dict[str, Any]) -> dict[str, Benchmark]:
148+
"""Extract normalized benchmarks from a compact mapping JSON object.
149+
150+
Args:
151+
data (dict[str, Any]): Parsed mapping where each benchmark is either a
152+
raw number or an object containing ``value``, ``unit``, and ``metric``.
153+
154+
Returns:
155+
dict[str, Benchmark]: Benchmarks keyed by mapping key.
156+
"""
157+
85158
benchmarks: dict[str, Benchmark] = {}
86159

87160
for name, raw_value in data.items():
@@ -103,6 +176,20 @@ def _extract_simple_mapping(data: dict[str, Any]) -> dict[str, Benchmark]:
103176

104177

105178
def extract_benchmarks(path: Path) -> dict[str, Benchmark]:
179+
"""Extract normalized benchmarks from a supported JSON file.
180+
181+
Args:
182+
path (Path): Path to a hyperfine, pytest-benchmark, or compact mapping
183+
JSON file.
184+
185+
Returns:
186+
dict[str, Benchmark]: Normalized benchmarks keyed by name.
187+
188+
Raises:
189+
ValueError: If the JSON root is not an object or no supported benchmark
190+
entries can be extracted.
191+
"""
192+
106193
data = _read_json(path)
107194
if not isinstance(data, dict):
108195
raise ValueError(f"{path} must contain a JSON object")
@@ -122,6 +209,22 @@ def compare_benchmarks(
122209
threshold_percent: float,
123210
higher_is_better: bool,
124211
) -> tuple[list[Comparison], list[str], list[str]]:
212+
"""Compare baseline benchmarks with current benchmarks.
213+
214+
Args:
215+
baseline (dict[str, Benchmark]): Baseline benchmarks keyed by name.
216+
current (dict[str, Benchmark]): Current benchmarks keyed by name.
217+
threshold_percent (float): Regression threshold in percent.
218+
higher_is_better (bool): If ``True``, lower current values are treated as
219+
regressions. If ``False``, higher current values are treated as
220+
regressions.
221+
222+
Returns:
223+
tuple[list[Comparison], list[str], list[str]]: Comparisons for common
224+
benchmark names, names missing from current results, and names newly
225+
present in current results.
226+
"""
227+
125228
comparisons: list[Comparison] = []
126229
missing_in_current: list[str] = []
127230
new_in_current: list[str] = []
@@ -165,6 +268,16 @@ def compare_benchmarks(
165268

166269

167270
def _format_value(value: float, unit: str) -> str:
271+
"""Format a benchmark value for Markdown output.
272+
273+
Args:
274+
value (float): Numeric benchmark value.
275+
unit (str): Display unit.
276+
277+
Returns:
278+
str: Formatted value with optional unit suffix.
279+
"""
280+
168281
suffix = f" {unit}" if unit else ""
169282
return f"{value:.6g}{suffix}"
170283

@@ -177,6 +290,20 @@ def write_summary(
177290
threshold_percent: float,
178291
higher_is_better: bool,
179292
) -> None:
293+
"""Write a Markdown benchmark comparison summary.
294+
295+
Args:
296+
path (Path): Path where the summary should be written.
297+
comparisons (list[Comparison]): Comparison rows for matching benchmarks.
298+
missing_in_current (list[str]): Baseline benchmark names missing from the
299+
current result.
300+
new_in_current (list[str]): Current benchmark names not present in the
301+
baseline result.
302+
threshold_percent (float): Regression threshold in percent.
303+
higher_is_better (bool): Whether higher benchmark values are considered
304+
better.
305+
"""
306+
180307
regressions = [comparison for comparison in comparisons if comparison.regressed]
181308
direction = "higher is better" if higher_is_better else "lower is better"
182309
sorted_comparisons = sorted(comparisons, key=lambda comparison: comparison.name)
@@ -245,6 +372,12 @@ def write_summary(
245372

246373

247374
def main() -> int:
375+
"""Run the benchmark comparison command line interface.
376+
377+
Returns:
378+
int: ``1`` when a regression exceeds the threshold, otherwise ``0``.
379+
"""
380+
248381
parser = argparse.ArgumentParser()
249382
parser.add_argument("--baseline", required=True, type=Path)
250383
parser.add_argument("--current", required=True, type=Path)

0 commit comments

Comments
 (0)