Skip to content

Commit ef318da

Browse files
committed
feat: add ResetBenchmarkRunner
1 parent 24d35be commit ef318da

2 files changed

Lines changed: 581 additions & 4 deletions

File tree

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,23 @@
1+
# -*- coding: ascii -*-
2+
3+
"""
4+
Reset benchmark runner implementation.
5+
6+
This module provides ResetBenchmarkRunner - a benchmark that runs the
7+
solver normally, resetting the algorithm on every detected change point.
8+
Results are cached via BenchmarkExecutor.
9+
"""
10+
11+
__author__ = "Danil Totmyanin"
12+
__copyright__ = "Copyright (c) 2026 PySATL project"
13+
__license__ = "SPDX-License-Identifier: MIT"
14+
115
from collections.abc import Sequence
216
from pathlib import Path
3-
from typing import Any
17+
from typing import Any, cast
418

519
from pysatl_cpd.analysis.labeled_data import LabeledData
20+
from pysatl_cpd.benchmark.core.benchmark_executor import BenchmarkExecutor
621
from pysatl_cpd.benchmark.metrics.multiple_run_metric import MultipleRunMetric
722
from pysatl_cpd.benchmark.online_benchmark_runner import OnlineBenchmarkRunner
823
from pysatl_cpd.core.online.ionline_algorithm import OnlineAlgorithm
@@ -13,20 +28,91 @@
1328
class ResetBenchmarkRunner[TraceT: OnlineDetectionTrace[Any], ProviderT: LabeledData[Any]](
1429
OnlineBenchmarkRunner[TraceT, ProviderT]
1530
):
31+
"""
32+
Benchmark runner that uses standard reset behaviour.
33+
34+
For each (algorithm, threshold) pair, runs the solver over all
35+
providers via BenchmarkExecutor. The algorithm is reset on every
36+
detected change point (standard solver behaviour). Results are
37+
cached to disk when dump_dir is provided.
38+
39+
Parameters
40+
----------
41+
algorithms : Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]]
42+
Sequence of (algorithm, thresholds) pairs to evaluate.
43+
providers : Sequence[ProviderT]
44+
Labeled data providers to run against.
45+
metrics : dict[str, MultipleRunMetric[TraceT, ProviderT, Any]]
46+
Named metrics to evaluate for each (algorithm, threshold) batch.
47+
solver : OnlineCpdSolver
48+
Solver used to run algorithms against providers.
49+
dump_dir : Path | str | None, optional
50+
Directory for caching results via BenchmarkExecutor.
51+
If None, caching is disabled. Default is None.
52+
"""
53+
1654
def __init__(
1755
self,
1856
algorithms: Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]],
1957
providers: Sequence[ProviderT],
2058
metrics: dict[str, MultipleRunMetric[TraceT, ProviderT, Any]],
2159
solver: OnlineCpdSolver,
22-
dump_dir: Path | None = None,
60+
dump_dir: Path | str | None = None,
2361
) -> None:
24-
return
62+
super().__init__(
63+
algorithms=algorithms,
64+
providers=providers,
65+
metrics=metrics,
66+
solver=solver,
67+
dump_dir=dump_dir,
68+
)
2569

2670
def _collect_runs(
2771
self,
2872
algorithm: OnlineAlgorithm[Any, Any, Any],
2973
threshold: float,
3074
providers: Sequence[ProviderT],
3175
) -> list[tuple[TraceT, ProviderT]]:
32-
raise NotImplementedError("Method `_collect_runs` is not implemented yet.")
76+
"""
77+
Collect runs for a given algorithm and threshold via BenchmarkExecutor.
78+
79+
Creates a BenchmarkExecutor with a single threshold and all providers,
80+
executes it, and pairs each resulting trace with its provider.
81+
82+
Parameters
83+
----------
84+
algorithm : OnlineAlgorithm[Any, Any, Any]
85+
The algorithm to evaluate.
86+
threshold : float
87+
The detection threshold.
88+
providers : Sequence[ProviderT]
89+
Data providers to run against.
90+
91+
Returns
92+
-------
93+
list[tuple[TraceT, ProviderT]]
94+
List of (trace, provider) pairs, one per provider.
95+
"""
96+
if not providers:
97+
return []
98+
99+
executor: BenchmarkExecutor[Any] = BenchmarkExecutor(
100+
algorithms=[(algorithm, [threshold])],
101+
providers=list(providers),
102+
solver=self._solver,
103+
dump_dir=self._dump_dir,
104+
)
105+
106+
records_and_traces = executor.execute()
107+
108+
# BenchmarkExecutor returns (BenchmarkRecord, OnlineDetectionTrace) pairs.
109+
# We need to pair each trace with the correct provider.
110+
# Executor iterates providers in the same order as input.
111+
provider_by_name: dict[str, ProviderT] = {provider.name: provider for provider in providers}
112+
113+
runs: list[tuple[TraceT, ProviderT]] = []
114+
for record, trace in records_and_traces:
115+
provider = provider_by_name[record.data]
116+
runs.append((cast(TraceT, trace), provider))
117+
118+
return runs

0 commit comments

Comments
 (0)