Skip to content

Commit 403a5f7

Browse files
committed
docs(benchmark): benchmark executor
1 parent 6395405 commit 403a5f7

1 file changed

Lines changed: 104 additions & 22 deletions

File tree

pysatl_cpd/benchmark/benchmark_executor.py

Lines changed: 104 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,19 @@
1+
# -*- coding: ascii -*-
2+
"""
3+
Benchmark execution module for change-point detection algorithms.
4+
5+
This module provides the core components for running and caching performance
6+
evaluations of online CPD algorithms across multiple datasets and threshold
7+
configurations.
8+
"""
9+
10+
__author__ = "Danil Totmyanin"
11+
__copyright__ = "Copyright (c) 2026 PySATL project"
12+
__license__ = "SPDX-License-Identifier: MIT"
13+
114
import csv
15+
import hashlib
16+
import itertools
217
import math
318
import pickle
419
from collections.abc import Sequence
@@ -14,6 +29,27 @@
1429

1530
@dataclass
1631
class BenchmarkRecord:
32+
"""
33+
Metadata container for a single benchmark execution.
34+
35+
This record uniquely identifies a benchmark run and stores the path
36+
to the cached trace file if disk dumping is enabled.
37+
38+
Parameters
39+
----------
40+
algorithm : str
41+
The string identifier or name of the online algorithm.
42+
configuration_hash : str
43+
A hash string representing the algorithm's configuration.
44+
data : str
45+
The identifier or name of the dataset.
46+
threshold : float
47+
The detection threshold used for this specific run.
48+
trace_path : str | None, default=None
49+
Absolute or relative path to the serialized detection trace file,
50+
if caching is enabled.
51+
"""
52+
1753
algorithm: str
1854
configuration_hash: str
1955
data: str
@@ -22,10 +58,41 @@ class BenchmarkRecord:
2258

2359
@property
2460
def key(self) -> tuple[str, str, str, float]:
61+
"""
62+
Get the unique composite key for this benchmark run.
63+
64+
Returns
65+
-------
66+
tuple[str, str, str, float]
67+
A tuple containing (algorithm, configuration_hash, data, threshold)
68+
used for identifying the record in the registry.
69+
"""
2570
return (self.algorithm, self.configuration_hash, self.data, self.threshold)
2671

2772

2873
class BenchmarkExecutor[DataT]:
74+
"""
75+
Orchestrator for executing change-point detection benchmarks.
76+
77+
Evaluates a set of algorithms across multiple data providers and thresholds
78+
using a provided online solver. Supports a caching mechanism via disk dumping
79+
to prevent redundant calculations on subsequent runs.
80+
81+
Parameters
82+
----------
83+
algorithms : list[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]]
84+
A list of tuples, where each tuple contains an instantiated online
85+
algorithm and a sequence of thresholds to test it against.
86+
providers : list[DataProvider[DataT]]
87+
A list of data providers to be fed into the algorithms.
88+
solver : OnlineCpdSolver
89+
The solver instance responsible for iterating over the data providers
90+
and running the algorithmic logic.
91+
dump_dir : str | Path | None, optional
92+
Directory path where the benchmark registry (CSV) and serialized traces
93+
(Pickle files) should be stored. If None, caching is disabled.
94+
"""
95+
2996
def __init__(
3097
self,
3198
algorithms: list[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]],
@@ -39,6 +106,21 @@ def __init__(
39106
self.__dump_dir = Path(dump_dir) if dump_dir is not None else None
40107

41108
def execute(self) -> list[tuple[BenchmarkRecord, OnlineDetectionTrace[Any]]]:
109+
"""
110+
Execute the benchmark over all combinations of algorithms, data, and thresholds.
111+
112+
Each (algorithm, data provider) pair is evaluated at every threshold configured for that algorithm.
113+
If disk caching (`dump_dir`) is enabled, it attempts to load previously
114+
calculated traces from the registry to bypass solver execution. If a trace
115+
is missing, it runs the solver, caches the resulting trace to disk, and
116+
updates the CSV registry.
117+
118+
Returns
119+
-------
120+
list[tuple[BenchmarkRecord, OnlineDetectionTrace[Any]]]
121+
A list of execution results, where each element is a pair containing
122+
the benchmark metadata record and the corresponding detection trace.
123+
"""
42124
results: list[tuple[BenchmarkRecord, OnlineDetectionTrace[Any]]] = []
43125
registry: dict[tuple[str, str, str, float], BenchmarkRecord] = {}
44126
registry_path: Path | None = None
@@ -60,42 +142,42 @@ def execute(self) -> list[tuple[BenchmarkRecord, OnlineDetectionTrace[Any]]]:
60142
)
61143
registry[record.key] = record
62144

63-
for algorithm, thresholds in self.__algorithms:
145+
for (algorithm, thresholds), provider in itertools.product(self.__algorithms, self.__providers):
64146
algo_name = str(algorithm)
65-
config_hash = str(hash(algo_name))
66-
67-
for provider in self.__providers:
68-
data_name = provider.name
147+
config_hash = str(hashlib.md5(algo_name.encode("utf-8")).hexdigest()[:8])
148+
data_name = provider.name
69149

70-
for threshold in thresholds:
71-
key = (algo_name, config_hash, data_name, float(threshold))
150+
for threshold in thresholds:
151+
key = (algo_name, config_hash, data_name, float(threshold))
72152

73-
if key in registry and registry[key].trace_path:
74-
trace_file = Path(registry[key].trace_path) # type: ignore
153+
if key in registry:
154+
cached_path = registry[key].trace_path
155+
if cached_path is not None:
156+
trace_file = Path(cached_path)
75157
if trace_file.exists():
76158
with open(trace_file, "rb") as f:
77159
trace = pickle.load(f)
78160
results.append((registry[key], trace))
79161
continue
80162

81-
steps = list(self.__solver.run(algorithm, provider, threshold))
82-
trace = OnlineDetectionTrace.from_run(steps)
163+
steps = list(self.__solver.run(algorithm, provider, threshold))
164+
trace = OnlineDetectionTrace.from_run(steps)
83165

84-
record = BenchmarkRecord(algo_name, config_hash, data_name, threshold, None)
166+
record = BenchmarkRecord(algo_name, config_hash, data_name, threshold, None)
85167

86-
if self.__dump_dir is not None:
87-
safe_data_name = "".join(c if c.isalnum() else "_" for c in data_name)
88-
thr_str = "inf" if math.isinf(record.threshold) else f"{threshold:.4f}".replace(".", "_")
89-
filename = f"{algo_name}_{config_hash}_{safe_data_name}_{thr_str}.pkl"
168+
if self.__dump_dir is not None:
169+
safe_data_name = "".join(c if c.isalnum() else "_" for c in data_name)
170+
thr_str = "inf" if math.isinf(record.threshold) else f"{threshold:.4f}".replace(".", "_")
171+
filename = f"{algo_name}_{config_hash}_{safe_data_name}_{thr_str}.pkl"
90172

91-
trace_path = self.__dump_dir / filename
92-
with open(trace_path, "wb") as f:
93-
pickle.dump(trace, f)
173+
trace_path = self.__dump_dir / filename
174+
with open(trace_path, "wb") as f:
175+
pickle.dump(trace, f)
94176

95-
record.trace_path = str(trace_path)
96-
registry[key] = record
177+
record.trace_path = str(trace_path)
178+
registry[key] = record
97179

98-
results.append((record, trace))
180+
results.append((record, trace))
99181

100182
if registry_path is not None:
101183
fieldnames = ["algorithm", "configuration_hash", "data", "threshold", "trace_path"]

0 commit comments

Comments
 (0)