|
| 1 | +# -*- coding: ascii -*- |
| 2 | + |
| 3 | +""" |
| 4 | +NoReset benchmark runner implementation. |
| 5 | +
|
| 6 | +This module provides NoResetBenchmarkRunner - an optimised benchmark for |
| 7 | +series with a single change point. The solver is executed only once per |
| 8 | +(algorithm, provider) pair with threshold=inf, and all threshold |
| 9 | +evaluations are simulated via ThresholdPolicy on the cached trace. |
| 10 | +""" |
| 11 | + |
| 12 | +__author__ = "Danil Totmyanin" |
| 13 | +__copyright__ = "Copyright (c) 2026 PySATL project" |
| 14 | +__license__ = "SPDX-License-Identifier: MIT" |
| 15 | + |
1 | 16 | from collections.abc import Sequence |
2 | 17 | from pathlib import Path |
3 | 18 | from typing import Any |
4 | 19 |
|
5 | 20 | from pysatl_cpd.analysis.labeled_data import LabeledData |
| 21 | +from pysatl_cpd.benchmark.core.benchmark_executor import BenchmarkExecutor |
6 | 22 | from pysatl_cpd.benchmark.metrics.multiple_run_metric import MultipleRunMetric |
7 | 23 | from pysatl_cpd.benchmark.noreset.noreset_detection_trace import NoResetDetectionTrace |
8 | 24 | from pysatl_cpd.benchmark.noreset.threshold_policy import ThresholdPolicy |
|
13 | 29 |
|
14 | 30 |
|
15 | 31 | class NoResetBenchmarkRunner[ProviderT: LabeledData[Any]](OnlineBenchmarkRunner[NoResetDetectionTrace[Any], ProviderT]): |
| 32 | + """ |
| 33 | + Optimised benchmark runner for series with a single change point. |
| 34 | +
|
| 35 | + For each (algorithm, provider) pair the solver is executed exactly |
| 36 | + once with threshold=inf, producing a full detection function trace. |
| 37 | + All threshold evaluations are then simulated by applying a |
| 38 | + ThresholdPolicy to that cached trace, avoiding redundant solver runs. |
| 39 | + Caching is handled entirely by BenchmarkExecutor. |
| 40 | +
|
| 41 | + Parameters |
| 42 | + ---------- |
| 43 | + algorithms : Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]] |
| 44 | + Sequence of (algorithm, thresholds) pairs to evaluate. |
| 45 | + providers : Sequence[ProviderT] |
| 46 | + Labeled data providers to run against. |
| 47 | + metrics : dict[str, MultipleRunMetric[NoResetDetectionTrace[Any], ProviderT, Any]] |
| 48 | + Named metrics to evaluate for each (algorithm, threshold) batch. |
| 49 | + solver : OnlineCpdSolver |
| 50 | + Solver used to produce inf traces. |
| 51 | + policy : ThresholdPolicy |
| 52 | + Policy used to extract detected change points from the inf trace |
| 53 | + for each threshold. |
| 54 | + dump_dir : Path | str | None, optional |
| 55 | + Directory for caching inf traces via BenchmarkExecutor. |
| 56 | + If None, caching is disabled. Default is None. |
| 57 | + """ |
| 58 | + |
16 | 59 | def __init__( |
17 | 60 | self, |
18 | 61 | algorithms: Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]], |
19 | 62 | providers: Sequence[ProviderT], |
20 | 63 | metrics: dict[str, MultipleRunMetric[NoResetDetectionTrace[Any], ProviderT, Any]], |
21 | 64 | solver: OnlineCpdSolver, |
22 | 65 | policy: ThresholdPolicy, |
23 | | - dump_dir: Path | None = None, |
| 66 | + dump_dir: Path | str | None = None, |
24 | 67 | ) -> None: |
25 | | - return |
| 68 | + super().__init__( |
| 69 | + algorithms=algorithms, |
| 70 | + providers=providers, |
| 71 | + metrics=metrics, |
| 72 | + solver=solver, |
| 73 | + dump_dir=dump_dir, |
| 74 | + ) |
| 75 | + self._policy = policy |
| 76 | + |
| 77 | + def _get_inf_trace( |
| 78 | + self, |
| 79 | + algorithm: OnlineAlgorithm[Any, Any, Any], |
| 80 | + provider: ProviderT, |
| 81 | + ) -> OnlineDetectionTrace[Any]: |
| 82 | + """ |
| 83 | + Compute or retrieve the infinite-threshold trace for a given pair. |
| 84 | +
|
| 85 | + Delegates entirely to BenchmarkExecutor which handles disk caching |
| 86 | + when dump_dir is set. |
| 87 | +
|
| 88 | + Parameters |
| 89 | + ---------- |
| 90 | + algorithm : OnlineAlgorithm[Any, Any, Any] |
| 91 | + The algorithm to run. |
| 92 | + provider : ProviderT |
| 93 | + The data provider to run against. |
| 94 | +
|
| 95 | + Returns |
| 96 | + ------- |
| 97 | + OnlineDetectionTrace[Any] |
| 98 | + Trace produced with threshold=inf. |
| 99 | + """ |
| 100 | + executor: BenchmarkExecutor[Any] = BenchmarkExecutor( |
| 101 | + algorithms=[(algorithm, [float("inf")])], |
| 102 | + providers=[provider], |
| 103 | + solver=self._solver, |
| 104 | + dump_dir=self._dump_dir, |
| 105 | + ) |
| 106 | + _, inf_trace = executor.execute()[0] |
| 107 | + return inf_trace |
26 | 108 |
|
27 | 109 | def _collect_runs( |
28 | 110 | self, |
29 | 111 | algorithm: OnlineAlgorithm[Any, Any, Any], |
30 | 112 | threshold: float, |
31 | 113 | providers: Sequence[ProviderT], |
32 | 114 | ) -> list[tuple[NoResetDetectionTrace[Any], ProviderT]]: |
33 | | - raise NotImplementedError("Method '_collect_runs' is not implemented yet.") |
| 115 | + """ |
| 116 | + Collect NoReset runs for a given algorithm and threshold. |
34 | 117 |
|
35 | | - def _get_inf_trace( |
36 | | - self, |
37 | | - algorithm: OnlineAlgorithm[Any, Any, Any], |
38 | | - provider: ProviderT, |
39 | | - ) -> OnlineDetectionTrace[Any]: |
40 | | - raise NotImplementedError("Method '_get_inf_trace' is not implemented yet.") |
| 118 | + For each provider, retrieves the inf trace via BenchmarkExecutor |
| 119 | + and applies the ThresholdPolicy to produce a lightweight |
| 120 | + NoResetDetectionTrace. |
| 121 | +
|
| 122 | + Parameters |
| 123 | + ---------- |
| 124 | + algorithm : OnlineAlgorithm[Any, Any, Any] |
| 125 | + The algorithm to evaluate. |
| 126 | + threshold : float |
| 127 | + The detection threshold to simulate. |
| 128 | + providers : Sequence[ProviderT] |
| 129 | + Data providers to run against. |
| 130 | +
|
| 131 | + Returns |
| 132 | + ------- |
| 133 | + list[tuple[NoResetDetectionTrace[Any], ProviderT]] |
| 134 | + List of (noreset_trace, provider) pairs, one per provider. |
| 135 | + """ |
| 136 | + if not providers: |
| 137 | + return [] |
| 138 | + |
| 139 | + runs: list[tuple[NoResetDetectionTrace[Any], ProviderT]] = [] |
| 140 | + |
| 141 | + for provider in providers: |
| 142 | + inf_trace = self._get_inf_trace(algorithm, provider) |
| 143 | + |
| 144 | + detected_change_points: list[int] = self._policy.apply( |
| 145 | + inf_trace.detection_function, |
| 146 | + threshold, |
| 147 | + provider.change_points, |
| 148 | + ) |
| 149 | + |
| 150 | + noreset_trace = NoResetDetectionTrace.from_inf_trace( |
| 151 | + source_trace=inf_trace, |
| 152 | + detected_change_points=detected_change_points, |
| 153 | + threshold=threshold, |
| 154 | + ) |
| 155 | + |
| 156 | + runs.append((noreset_trace, provider)) |
| 157 | + |
| 158 | + return runs |
0 commit comments