Skip to content

Commit 24d35be

Browse files
committed
feat: add abstract OnlineBenchmarkRunner
1 parent 6803c09 commit 24d35be

7 files changed

Lines changed: 725 additions & 21 deletions

File tree

pysatl_cpd/benchmark/arl_benchmark_runner.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Sequence
12
from pathlib import Path
23
from typing import Any
34

@@ -13,7 +14,7 @@ class ARLBenchmarkRunner[TraceT: OnlineDetectionTrace[Any], ProviderT: LabeledDa
1314
):
1415
def __init__(
1516
self,
16-
algorithms: list[tuple[OnlineAlgorithm[Any, Any, Any], list[float]]],
17+
algorithms: Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]],
1718
providers: list[ProviderT],
1819
solver: OnlineCpdSolver,
1920
dump_dir: Path | None = None,
@@ -24,6 +25,6 @@ def _collect_runs(
2425
self,
2526
algorithm: OnlineAlgorithm[Any, Any, Any],
2627
threshold: float,
27-
providers: list[ProviderT],
28+
providers: Sequence[ProviderT],
2829
) -> list[tuple[TraceT, ProviderT]]:
2930
raise NotImplementedError("Method `_collect_runs` is not implemented yet.")

pysatl_cpd/benchmark/core/benchmark_executor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,11 @@ class BenchmarkExecutor[DataT]:
7979
8080
Parameters
8181
----------
82-
algorithms : list[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]]
83-
A list of tuples, where each tuple contains an instantiated online
82+
algorithms : Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]]
83+
A sequence of tuples, where each tuple contains an instantiated online
8484
algorithm and a sequence of thresholds to test it against.
85-
providers : list[DataProvider[DataT]]
86-
A list of data providers to be fed into the algorithms.
85+
providers : Sequence[DataProvider[DataT]]
86+
A sequence of data providers to be fed into the algorithms.
8787
solver : OnlineCpdSolver
8888
The solver instance responsible for iterating over the data providers
8989
and running the algorithmic logic.
@@ -94,8 +94,8 @@ class BenchmarkExecutor[DataT]:
9494

9595
def __init__(
9696
self,
97-
algorithms: list[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]],
98-
providers: list[DataProvider[DataT]],
97+
algorithms: Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]],
98+
providers: Sequence[DataProvider[DataT]],
9999
solver: OnlineCpdSolver,
100100
dump_dir: str | Path | None = None,
101101
) -> None:

pysatl_cpd/benchmark/noreset/noreset_benchmark_runner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Sequence
12
from pathlib import Path
23
from typing import Any
34

@@ -14,8 +15,8 @@
1415
class NoResetBenchmarkRunner[ProviderT: LabeledData[Any]](OnlineBenchmarkRunner[NoResetDetectionTrace[Any], ProviderT]):
1516
def __init__(
1617
self,
17-
algorithms: list[tuple[OnlineAlgorithm[Any, Any, Any], list[float]]],
18-
providers: list[ProviderT],
18+
algorithms: Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]],
19+
providers: Sequence[ProviderT],
1920
metrics: dict[str, MultipleRunMetric[NoResetDetectionTrace[Any], ProviderT, Any]],
2021
solver: OnlineCpdSolver,
2122
policy: ThresholdPolicy,
@@ -27,7 +28,7 @@ def _collect_runs(
2728
self,
2829
algorithm: OnlineAlgorithm[Any, Any, Any],
2930
threshold: float,
30-
providers: list[ProviderT],
31+
providers: Sequence[ProviderT],
3132
) -> list[tuple[NoResetDetectionTrace[Any], ProviderT]]:
3233
raise NotImplementedError("Method '_collect_runs' is not implemented yet.")
3334

pysatl_cpd/benchmark/online_benchmark_runner.py

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
1-
# online_runner.py
1+
# -*- coding: ascii -*-
2+
3+
"""
4+
Abstract base class for online benchmark runners.
5+
"""
6+
7+
__author__ = "Danil Totmyanin"
8+
__copyright__ = "Copyright (c) 2026 PySATL project"
9+
__license__ = "SPDX-License-Identifier: MIT"
10+
211
from abc import ABC, abstractmethod
12+
from collections.abc import Sequence
313
from pathlib import Path
414
from typing import Any
515

@@ -11,26 +21,102 @@
1121

1222

1323
class OnlineBenchmarkRunner[TraceT: OnlineDetectionTrace[Any], ProviderT: LabeledData[Any]](ABC):
24+
"""
25+
Abstract base class for online benchmark runners.
26+
27+
Organises the evaluation loop over algorithms and thresholds,
28+
delegates data collection to subclasses via _collect_runs(), and
29+
applies all registered metrics to each batch of runs.
30+
31+
Parameters
32+
----------
33+
algorithms : Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]]
34+
Sequence of (algorithm, thresholds) pairs to evaluate.
35+
providers : Sequence[ProviderT]
36+
Sequence of labeled data providers.
37+
metrics : dict[str, MultipleRunMetric[TraceT, ProviderT, Any]]
38+
Named metrics to evaluate for each (algorithm, threshold) batch.
39+
solver : OnlineCpdSolver
40+
Solver used to run algorithms against providers.
41+
dump_dir : Path | str | None, optional
42+
Directory for caching results via BenchmarkExecutor.
43+
If None, caching is disabled. Default is None.
44+
"""
45+
1446
def __init__(
1547
self,
16-
algorithms: list[tuple[OnlineAlgorithm[Any, Any, Any], list[float]]],
17-
providers: list[ProviderT],
48+
algorithms: Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]],
49+
providers: Sequence[ProviderT],
1850
metrics: dict[str, MultipleRunMetric[TraceT, ProviderT, Any]],
1951
solver: OnlineCpdSolver,
20-
dump_dir: Path | None = None,
52+
dump_dir: Path | str | None = None,
2153
) -> None:
22-
return
54+
self._algorithms = algorithms
55+
self._providers = providers
56+
self._metrics = metrics
57+
self._solver = solver
58+
self._dump_dir = Path(dump_dir) if isinstance(dump_dir, str) else dump_dir
2359

2460
@abstractmethod
2561
def _collect_runs(
2662
self,
2763
algorithm: OnlineAlgorithm[Any, Any, Any],
2864
threshold: float,
29-
providers: list[ProviderT],
65+
providers: Sequence[ProviderT],
3066
) -> list[tuple[TraceT, ProviderT]]:
67+
"""
68+
Collect (trace, provider) pairs for a given algorithm and threshold.
69+
70+
Parameters
71+
----------
72+
algorithm : OnlineAlgorithm[Any, Any, Any]
73+
The algorithm to evaluate.
74+
threshold : float
75+
The detection threshold.
76+
providers : Sequence[ProviderT]
77+
Sequence of data providers to run against.
78+
79+
Returns
80+
-------
81+
list[tuple[TraceT, ProviderT]]
82+
Batch of (trace, provider) pairs for metric evaluation.
83+
"""
84+
3185
raise NotImplementedError("Method `_collect_runs` is not implemented yet.")
3286

3387
def run(
3488
self,
3589
) -> dict[tuple[str, OnlineAlgorithmConfiguration], list[tuple[float, dict[str, Any]]]]:
36-
raise NotImplementedError("Method `run` is not implemented yet.")
90+
"""
91+
Execute the benchmark over all algorithms and thresholds.
92+
93+
For each (algorithm, threshold) pair, collects runs via
94+
_collect_runs() and evaluates all registered metrics.
95+
96+
Returns
97+
-------
98+
dict[tuple[str, OnlineAlgorithmConfiguration], list[tuple[float, dict[str, Any]]]]
99+
Mapping of (algorithm_name, configuration) to a list of
100+
(threshold, {metric_name: metric_value}) entries, one per threshold.
101+
"""
102+
103+
results: dict[
104+
tuple[str, OnlineAlgorithmConfiguration],
105+
list[tuple[float, dict[str, Any]]],
106+
] = {}
107+
108+
for algorithm, thresholds in self._algorithms:
109+
key: tuple[str, OnlineAlgorithmConfiguration] = (
110+
str(algorithm),
111+
algorithm.configuration,
112+
)
113+
results[key] = []
114+
115+
for threshold in thresholds:
116+
runs = self._collect_runs(algorithm, threshold, self._providers)
117+
118+
metric_values: dict[str, Any] = {name: metric.evaluate(runs) for name, metric in self._metrics.items()}
119+
120+
results[key].append((threshold, metric_values))
121+
122+
return results

pysatl_cpd/benchmark/reset_benchmark_runner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Sequence
12
from pathlib import Path
23
from typing import Any
34

@@ -14,8 +15,8 @@ class ResetBenchmarkRunner[TraceT: OnlineDetectionTrace[Any], ProviderT: Labeled
1415
):
1516
def __init__(
1617
self,
17-
algorithms: list[tuple[OnlineAlgorithm[Any, Any, Any], list[float]]],
18-
providers: list[ProviderT],
18+
algorithms: Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]],
19+
providers: Sequence[ProviderT],
1920
metrics: dict[str, MultipleRunMetric[TraceT, ProviderT, Any]],
2021
solver: OnlineCpdSolver,
2122
dump_dir: Path | None = None,
@@ -26,6 +27,6 @@ def _collect_runs(
2627
self,
2728
algorithm: OnlineAlgorithm[Any, Any, Any],
2829
threshold: float,
29-
providers: list[ProviderT],
30+
providers: Sequence[ProviderT],
3031
) -> list[tuple[TraceT, ProviderT]]:
3132
raise NotImplementedError("Method `_collect_runs` is not implemented yet.")
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# -*- coding: ascii -*-
2+
3+
"""
4+
Mock OnlineBenchmarkRunner for testing.
5+
"""
6+
7+
__author__ = "Danil Totmyanin"
8+
__copyright__ = "Copyright (c) 2026 PySATL project"
9+
__license__ = "SPDX-License-Identifier: MIT"
10+
11+
from collections.abc import Sequence
12+
from pathlib import Path
13+
from typing import Any
14+
15+
from pysatl_cpd.analysis.labeled_data import LabeledData
16+
from pysatl_cpd.benchmark.metrics.multiple_run_metric import MultipleRunMetric
17+
from pysatl_cpd.benchmark.online_benchmark_runner import OnlineBenchmarkRunner
18+
from pysatl_cpd.core.online.ionline_algorithm import OnlineAlgorithm
19+
from pysatl_cpd.core.online.online_cpd_solver import OnlineCpdSolver
20+
from pysatl_cpd.core.online.online_detection_trace import OnlineDetectionTrace
21+
22+
23+
class MockBenchmarkRunner[TraceT: OnlineDetectionTrace[Any], ProviderT: LabeledData[Any]](
24+
OnlineBenchmarkRunner[TraceT, ProviderT]
25+
):
26+
"""
27+
Mock implementation of OnlineBenchmarkRunner for testing.
28+
29+
Records all _collect_runs calls for assertion in tests.
30+
Returns a pre-configured list of runs for each call.
31+
32+
Parameters
33+
----------
34+
algorithms : Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]]
35+
Sequence of (algorithm, thresholds) pairs.
36+
providers : Sequence[ProviderT]
37+
Sequence of data providers.
38+
metrics : dict[str, MultipleRunMetric[TraceT, ProviderT, Any]]
39+
Dictionary of metrics to evaluate.
40+
solver : OnlineCpdSolver
41+
Solver instance.
42+
dump_dir : Path | str | None, optional
43+
Directory for caching results.
44+
runs_to_return : list[tuple[TraceT, ProviderT]] | None, optional
45+
Pre-configured runs returned by _collect_runs.
46+
If None, returns empty list.
47+
"""
48+
49+
def __init__(
50+
self,
51+
algorithms: Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]],
52+
providers: Sequence[ProviderT],
53+
metrics: dict[str, MultipleRunMetric[TraceT, ProviderT, Any]],
54+
solver: OnlineCpdSolver,
55+
dump_dir: Path | str | None = None,
56+
runs_to_return: list[tuple[TraceT, ProviderT]] | None = None,
57+
) -> None:
58+
super().__init__(
59+
algorithms=algorithms,
60+
providers=providers,
61+
metrics=metrics,
62+
solver=solver,
63+
dump_dir=dump_dir,
64+
)
65+
self._runs_to_return: list[tuple[TraceT, ProviderT]] = runs_to_return or []
66+
self.collect_runs_calls: list[tuple[OnlineAlgorithm[Any, Any, Any], float, Sequence[ProviderT]]] = []
67+
68+
def _collect_runs(
69+
self,
70+
algorithm: OnlineAlgorithm[Any, Any, Any],
71+
threshold: float,
72+
providers: Sequence[ProviderT],
73+
) -> list[tuple[TraceT, ProviderT]]:
74+
"""
75+
Record the call and return pre-configured runs.
76+
77+
Parameters
78+
----------
79+
algorithm : OnlineAlgorithm[Any, Any, Any]
80+
The algorithm being evaluated.
81+
threshold : float
82+
The detection threshold.
83+
providers : Sequence[ProviderT]
84+
Sequence of data providers.
85+
86+
Returns
87+
-------
88+
list[tuple[TraceT, ProviderT]]
89+
Pre-configured runs set at construction time.
90+
"""
91+
self.collect_runs_calls.append((algorithm, threshold, providers))
92+
return self._runs_to_return

0 commit comments

Comments
 (0)