|
1 | | -# online_runner.py |
| 1 | +# -*- coding: ascii -*- |
| 2 | + |
| 3 | +""" |
| 4 | +Abstract base class for online benchmark runners. |
| 5 | +""" |
| 6 | + |
| 7 | +__author__ = "Danil Totmyanin" |
| 8 | +__copyright__ = "Copyright (c) 2026 PySATL project" |
| 9 | +__license__ = "SPDX-License-Identifier: MIT" |
| 10 | + |
2 | 11 | from abc import ABC, abstractmethod |
| 12 | +from collections.abc import Sequence |
3 | 13 | from pathlib import Path |
4 | 14 | from typing import Any |
5 | 15 |
|
|
11 | 21 |
|
12 | 22 |
|
13 | 23 | class OnlineBenchmarkRunner[TraceT: OnlineDetectionTrace[Any], ProviderT: LabeledData[Any]](ABC): |
| 24 | + """ |
| 25 | + Abstract base class for online benchmark runners. |
| 26 | +
|
| 27 | + Organises the evaluation loop over algorithms and thresholds, |
| 28 | + delegates data collection to subclasses via _collect_runs(), and |
| 29 | + applies all registered metrics to each batch of runs. |
| 30 | +
|
| 31 | + Parameters |
| 32 | + ---------- |
| 33 | + algorithms : Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]] |
| 34 | + Sequence of (algorithm, thresholds) pairs to evaluate. |
| 35 | + providers : Sequence[ProviderT] |
| 36 | + Sequence of labeled data providers. |
| 37 | + metrics : dict[str, MultipleRunMetric[TraceT, ProviderT, Any]] |
| 38 | + Named metrics to evaluate for each (algorithm, threshold) batch. |
| 39 | + solver : OnlineCpdSolver |
| 40 | + Solver used to run algorithms against providers. |
| 41 | + dump_dir : Path | str | None, optional |
| 42 | + Directory for caching results via BenchmarkExecutor. |
| 43 | + If None, caching is disabled. Default is None. |
| 44 | + """ |
| 45 | + |
14 | 46 | def __init__( |
15 | 47 | self, |
16 | | - algorithms: list[tuple[OnlineAlgorithm[Any, Any, Any], list[float]]], |
17 | | - providers: list[ProviderT], |
| 48 | + algorithms: Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]], |
| 49 | + providers: Sequence[ProviderT], |
18 | 50 | metrics: dict[str, MultipleRunMetric[TraceT, ProviderT, Any]], |
19 | 51 | solver: OnlineCpdSolver, |
20 | | - dump_dir: Path | None = None, |
| 52 | + dump_dir: Path | str | None = None, |
21 | 53 | ) -> None: |
22 | | - return |
| 54 | + self._algorithms = algorithms |
| 55 | + self._providers = providers |
| 56 | + self._metrics = metrics |
| 57 | + self._solver = solver |
| 58 | + self._dump_dir = Path(dump_dir) if isinstance(dump_dir, str) else dump_dir |
23 | 59 |
|
24 | 60 | @abstractmethod |
25 | 61 | def _collect_runs( |
26 | 62 | self, |
27 | 63 | algorithm: OnlineAlgorithm[Any, Any, Any], |
28 | 64 | threshold: float, |
29 | | - providers: list[ProviderT], |
| 65 | + providers: Sequence[ProviderT], |
30 | 66 | ) -> list[tuple[TraceT, ProviderT]]: |
| 67 | + """ |
| 68 | + Collect (trace, provider) pairs for a given algorithm and threshold. |
| 69 | +
|
| 70 | + Parameters |
| 71 | + ---------- |
| 72 | + algorithm : OnlineAlgorithm[Any, Any, Any] |
| 73 | + The algorithm to evaluate. |
| 74 | + threshold : float |
| 75 | + The detection threshold. |
| 76 | + providers : Sequence[ProviderT] |
| 77 | + Sequence of data providers to run against. |
| 78 | +
|
| 79 | + Returns |
| 80 | + ------- |
| 81 | + list[tuple[TraceT, ProviderT]] |
| 82 | + Batch of (trace, provider) pairs for metric evaluation. |
| 83 | + """ |
| 84 | + |
31 | 85 | raise NotImplementedError("Method `_collect_runs` is not implemented yet.") |
32 | 86 |
|
33 | 87 | def run( |
34 | 88 | self, |
35 | 89 | ) -> dict[tuple[str, OnlineAlgorithmConfiguration], list[tuple[float, dict[str, Any]]]]: |
36 | | - raise NotImplementedError("Method `run` is not implemented yet.") |
| 90 | + """ |
| 91 | + Execute the benchmark over all algorithms and thresholds. |
| 92 | +
|
| 93 | + For each (algorithm, threshold) pair, collects runs via |
| 94 | + _collect_runs() and evaluates all registered metrics. |
| 95 | +
|
| 96 | + Returns |
| 97 | + ------- |
| 98 | + dict[tuple[str, OnlineAlgorithmConfiguration], list[tuple[float, dict[str, Any]]]] |
| 99 | + Mapping of (algorithm_name, configuration) to a list of |
| 100 | + (threshold, {metric_name: metric_value}) entries, one per threshold. |
| 101 | + """ |
| 102 | + |
| 103 | + results: dict[ |
| 104 | + tuple[str, OnlineAlgorithmConfiguration], |
| 105 | + list[tuple[float, dict[str, Any]]], |
| 106 | + ] = {} |
| 107 | + |
| 108 | + for algorithm, thresholds in self._algorithms: |
| 109 | + key: tuple[str, OnlineAlgorithmConfiguration] = ( |
| 110 | + str(algorithm), |
| 111 | + algorithm.configuration, |
| 112 | + ) |
| 113 | + results[key] = [] |
| 114 | + |
| 115 | + for threshold in thresholds: |
| 116 | + runs = self._collect_runs(algorithm, threshold, self._providers) |
| 117 | + |
| 118 | + metric_values: dict[str, Any] = {name: metric.evaluate(runs) for name, metric in self._metrics.items()} |
| 119 | + |
| 120 | + results[key].append((threshold, metric_values)) |
| 121 | + |
| 122 | + return results |
0 commit comments