Skip to content

Commit 80ec38e

Browse files
committed
accommodate non-list metrics in baselines
1 parent 17ac9ae commit 80ec38e

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

sklbench/utils/measurement.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,22 +65,31 @@ def enrich_metrics(
6565
"""Transforms raw performance and other results into aggregated metrics"""
6666
# time metrics
6767
res = bench_result.copy()
68-
mean, std = box_filter(res["time[ms]"])
69-
if include_performance_stability_metrics:
68+
if isinstance(res["time[ms]"], list):
69+
mean, std = box_filter(res["time[ms]"])
70+
if include_performance_stability_metrics:
71+
res.update(
72+
{
73+
"1st run time[ms]": res["time[ms]"][0],
74+
"1st-mean run ratio": res["time[ms]"][0] / mean,
75+
}
76+
)
7077
res.update(
7178
{
72-
"1st run time[ms]": res["time[ms]"][0],
73-
"1st-mean run ratio": res["time[ms]"][0] / mean,
79+
"time[ms]": mean,
80+
"time CV": std / mean, # Coefficient of Variation
7481
}
7582
)
76-
res.update(
77-
{
78-
"time[ms]": mean,
79-
"time CV": std / mean, # Coefficient of Variation
80-
}
81-
)
83+
else:
84+
# already aggregated (e.g. from a baseline file)
85+
mean = res["time[ms]"]
86+
std = res.get("time std[ms]", 0.0)
87+
if mean != 0:
88+
res["time CV"] = std / mean
89+
else:
90+
res["time CV"] = 0.0
8291
cost = res.get("cost[microdollar]", None)
83-
if cost:
92+
if cost and isinstance(cost, list):
8493
res["cost[microdollar]"] = box_filter(res["cost[microdollar]"])[0]
8594
batch_size = res.get("batch_size", None)
8695
if batch_size:

0 commit comments

Comments
 (0)