|
| 1 | +# -*- coding: ascii -*- |
| 2 | + |
| 3 | +""" |
| 4 | +Reset benchmark runner implementation. |
| 5 | +
|
| 6 | +This module provides ResetBenchmarkRunner - a benchmark that runs the |
| 7 | +solver normally, resetting the algorithm on every detected change point. |
| 8 | +Results are cached via BenchmarkExecutor. |
| 9 | +""" |
| 10 | + |
| 11 | +__author__ = "Danil Totmyanin" |
| 12 | +__copyright__ = "Copyright (c) 2026 PySATL project" |
| 13 | +__license__ = "SPDX-License-Identifier: MIT" |
| 14 | + |
1 | 15 | from collections.abc import Sequence |
2 | 16 | from pathlib import Path |
3 | | -from typing import Any |
| 17 | +from typing import Any, cast |
4 | 18 |
|
5 | 19 | from pysatl_cpd.analysis.labeled_data import LabeledData |
| 20 | +from pysatl_cpd.benchmark.core.benchmark_executor import BenchmarkExecutor |
6 | 21 | from pysatl_cpd.benchmark.metrics.multiple_run_metric import MultipleRunMetric |
7 | 22 | from pysatl_cpd.benchmark.online_benchmark_runner import OnlineBenchmarkRunner |
8 | 23 | from pysatl_cpd.core.online.ionline_algorithm import OnlineAlgorithm |
|
13 | 28 | class ResetBenchmarkRunner[TraceT: OnlineDetectionTrace[Any], ProviderT: LabeledData[Any]]( |
14 | 29 | OnlineBenchmarkRunner[TraceT, ProviderT] |
15 | 30 | ): |
| 31 | + """ |
| 32 | + Benchmark runner that uses standard reset behaviour. |
| 33 | +
|
| 34 | + For each (algorithm, threshold) pair, runs the solver over all |
| 35 | + providers via BenchmarkExecutor. The algorithm is reset on every |
| 36 | + detected change point (standard solver behaviour). Results are |
| 37 | + cached to disk when dump_dir is provided. |
| 38 | +
|
| 39 | + Parameters |
| 40 | + ---------- |
| 41 | + algorithms : Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]] |
| 42 | + Sequence of (algorithm, thresholds) pairs to evaluate. |
| 43 | + providers : Sequence[ProviderT] |
| 44 | + Labeled data providers to run against. |
| 45 | + metrics : dict[str, MultipleRunMetric[TraceT, ProviderT, Any]] |
| 46 | + Named metrics to evaluate for each (algorithm, threshold) batch. |
| 47 | + solver : OnlineCpdSolver |
| 48 | + Solver used to run algorithms against providers. |
| 49 | + dump_dir : Path | str | None, optional |
| 50 | + Directory for caching results via BenchmarkExecutor. |
| 51 | + If None, caching is disabled. Default is None. |
| 52 | + """ |
| 53 | + |
16 | 54 | def __init__( |
17 | 55 | self, |
18 | 56 | algorithms: Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]], |
19 | 57 | providers: Sequence[ProviderT], |
20 | 58 | metrics: dict[str, MultipleRunMetric[TraceT, ProviderT, Any]], |
21 | 59 | solver: OnlineCpdSolver, |
22 | | - dump_dir: Path | None = None, |
| 60 | + dump_dir: Path | str | None = None, |
23 | 61 | ) -> None: |
24 | | - return |
| 62 | + super().__init__( |
| 63 | + algorithms=algorithms, |
| 64 | + providers=providers, |
| 65 | + metrics=metrics, |
| 66 | + solver=solver, |
| 67 | + dump_dir=dump_dir, |
| 68 | + ) |
25 | 69 |
|
26 | 70 | def _collect_runs( |
27 | 71 | self, |
28 | 72 | algorithm: OnlineAlgorithm[Any, Any, Any], |
29 | 73 | threshold: float, |
30 | 74 | providers: Sequence[ProviderT], |
31 | 75 | ) -> list[tuple[TraceT, ProviderT]]: |
32 | | - raise NotImplementedError("Method `_collect_runs` is not implemented yet.") |
| 76 | + """ |
| 77 | + Collect runs for a given algorithm and threshold via BenchmarkExecutor. |
| 78 | +
|
| 79 | + Creates a BenchmarkExecutor with a single threshold and all providers, |
| 80 | + executes it, and pairs each resulting trace with its provider. |
| 81 | +
|
| 82 | + Parameters |
| 83 | + ---------- |
| 84 | + algorithm : OnlineAlgorithm[Any, Any, Any] |
| 85 | + The algorithm to evaluate. |
| 86 | + threshold : float |
| 87 | + The detection threshold. |
| 88 | + providers : Sequence[ProviderT] |
| 89 | + Data providers to run against. |
| 90 | +
|
| 91 | + Returns |
| 92 | + ------- |
| 93 | + list[tuple[TraceT, ProviderT]] |
| 94 | + List of (trace, provider) pairs, one per provider. |
| 95 | + """ |
| 96 | + if not providers: |
| 97 | + return [] |
| 98 | + |
| 99 | + executor: BenchmarkExecutor[Any] = BenchmarkExecutor( |
| 100 | + algorithms=[(algorithm, [threshold])], |
| 101 | + providers=list(providers), |
| 102 | + solver=self._solver, |
| 103 | + dump_dir=self._dump_dir, |
| 104 | + ) |
| 105 | + |
| 106 | + records_and_traces = executor.execute() |
| 107 | + |
| 108 | + # BenchmarkExecutor returns (BenchmarkRecord, OnlineDetectionTrace) pairs. |
| 109 | + # We need to pair each trace with the correct provider. |
| 110 | + # Executor iterates providers in the same order as input. |
| 111 | + provider_by_name: dict[str, ProviderT] = {provider.name: provider for provider in providers} |
| 112 | + |
| 113 | + runs: list[tuple[TraceT, ProviderT]] = [] |
| 114 | + for record, trace in records_and_traces: |
| 115 | + provider = provider_by_name[record.data] |
| 116 | + runs.append((cast(TraceT, trace), provider)) |
| 117 | + |
| 118 | + return runs |
0 commit comments