Skip to content

Commit 397d69c

Browse files
Merge pull request #83 from PySATL/iraedeus/dimension-problem
feat: Add DataTransformers and update benchmarking and tests
2 parents cd27488 + d12475c commit 397d69c

28 files changed

Lines changed: 1298 additions & 475 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -170,3 +170,4 @@ cython_debug/
170170
*.jpeg
171171

172172
assets/data
173+
benchmark_cache/

examples/noreset_shewhart.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
from pysatl_cpd.benchmark.metrics.online.delay_metric import MeanDelayMetric, MedianDelayMetric
1414
from pysatl_cpd.benchmark.noreset.noreset_benchmark_runner import NoResetBenchmarkRunner
1515
from pysatl_cpd.benchmark.noreset.threshold_policy import EventBasedPolicy
16+
from pysatl_cpd.core.algorithm_entry import AlgorithmEntry
1617
from pysatl_cpd.core.online.online_cpd_solver import OnlineCpdSolver
1718

1819
# ---------------------------------------------------------------------------
@@ -169,7 +170,7 @@ def main() -> None:
169170
print(f"Algorithm: ShewhartControlChart(learning_period={LEARNING_PERIOD}, window={WINDOW_SIZE})")
170171
print(
171172
f"Dataset (NoReset): {N_SERIES} series, length={SERIES_LENGTH}, change_point={CHANGE_POINT},"
172-
"shift={MU_AFTER - MU_BEFORE:.1f}*sigma"
173+
f"shift={MU_AFTER - MU_BEFORE:.1f}*sigma"
173174
)
174175
print(f"Dataset (ARL): {N_SERIES} series, length={SERIES_LENGTH}, no change points")
175176
print(f"Error margin: {ERROR_MARGIN}")
@@ -192,7 +193,7 @@ def main() -> None:
192193
policy = EventBasedPolicy(ERROR_MARGIN[1], strict_edge=False)
193194

194195
runner = NoResetBenchmarkRunner(
195-
algorithms=[(algorithm, THRESHOLDS)],
196+
entries=[AlgorithmEntry(algorithm, THRESHOLDS)],
196197
providers=providers,
197198
metrics=metrics,
198199
solver=solver,
@@ -206,7 +207,7 @@ def main() -> None:
206207
# RUN 2: Average Run Length (ARL)
207208
# ==========================================
208209
arl_runner = ARLBenchmarkRunner(
209-
algorithms=[(algorithm, THRESHOLDS)],
210+
entries=[AlgorithmEntry(algorithm, THRESHOLDS)],
210211
providers=arl_providers,
211212
solver=solver,
212213
mode="noreset", # uses rapid point-based extraction behind the scenes

pysatl_cpd/benchmark/arl_benchmark_runner.py

Lines changed: 15 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
from pysatl_cpd.benchmark.noreset.threshold_policy import PointBasedPolicy
2323
from pysatl_cpd.benchmark.online_benchmark_runner import OnlineBenchmarkRunner
2424
from pysatl_cpd.benchmark.reset_benchmark_runner import ResetBenchmarkRunner
25-
from pysatl_cpd.core.online.ionline_algorithm import OnlineAlgorithm
25+
from pysatl_cpd.core.algorithm_entry import AlgorithmEntry
2626
from pysatl_cpd.core.online.online_cpd_solver import OnlineCpdSolver
2727
from pysatl_cpd.core.online.online_detection_trace import OnlineDetectionTrace
2828

@@ -44,8 +44,9 @@ class ARLBenchmarkRunner[TraceT: OnlineDetectionTrace[Any], ProviderT: LabeledDa
4444
4545
Parameters
4646
----------
47-
algorithms : Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]]
48-
Sequence of (algorithm, thresholds) pairs to evaluate.
47+
entries : Sequence[AlgorithmEntry]
48+
Sequence of AlgorithmEntry objects, each containing an algorithm, its thresholds,
49+
and an optional data transformer.
4950
providers : list[ProviderT]
5051
Labeled data providers to run against. Must have `change_points == []`.
5152
solver : OnlineCpdSolver
@@ -55,6 +56,8 @@ class ARLBenchmarkRunner[TraceT: OnlineDetectionTrace[Any], ProviderT: LabeledDa
5556
dump_dir : Path | str | None, optional
5657
Directory for caching results via BenchmarkExecutor.
5758
If None, caching is disabled. Default is None.
59+
verbose : bool, default=False
60+
If True, displays progress bars during execution.
5861
5962
Raises
6063
------
@@ -66,7 +69,7 @@ class ARLBenchmarkRunner[TraceT: OnlineDetectionTrace[Any], ProviderT: LabeledDa
6669

6770
def __init__(
6871
self,
69-
algorithms: Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]],
72+
entries: Sequence[AlgorithmEntry[Any, Any, Any]],
7073
providers: list[ProviderT],
7174
solver: OnlineCpdSolver,
7275
mode: Literal["reset", "noreset"],
@@ -83,7 +86,7 @@ def __init__(
8386
metrics = {"arl": ARLMetric[TraceT, ProviderT]()}
8487

8588
super().__init__(
86-
algorithms=algorithms,
89+
entries=entries,
8790
providers=providers,
8891
metrics=metrics, # type: ignore[arg-type]
8992
solver=solver,
@@ -95,7 +98,7 @@ def __init__(
9598
if mode == "reset":
9699
# Delegate to standard ResetBenchmarkRunner
97100
self._inner_runner: OnlineBenchmarkRunner[Any, ProviderT] = ResetBenchmarkRunner(
98-
algorithms=algorithms,
101+
entries=entries,
99102
providers=providers,
100103
metrics=cast(Any, metrics),
101104
solver=solver,
@@ -104,7 +107,7 @@ def __init__(
104107
elif mode == "noreset":
105108
# Delegate to optimized NoResetBenchmarkRunner with PointBased policy
106109
self._inner_runner = NoResetBenchmarkRunner(
107-
algorithms=algorithms,
110+
entries=entries,
108111
providers=providers,
109112
metrics=cast(Any, metrics),
110113
solver=solver,
@@ -116,20 +119,20 @@ def __init__(
116119

117120
def _collect_runs(
118121
self,
119-
algorithm: OnlineAlgorithm[Any, Any, Any],
122+
entry: AlgorithmEntry[Any, Any, Any],
120123
threshold: float,
121124
providers: Sequence[ProviderT],
122125
) -> list[tuple[TraceT, ProviderT]]:
123126
"""
124-
Collect runs for a given algorithm and threshold using the configured mode.
127+
Collect runs for a given algorithm entry and threshold using the configured mode.
125128
126129
Delegates the collection to either ResetBenchmarkRunner or
127130
NoResetBenchmarkRunner depending on the initialized mode.
128131
129132
Parameters
130133
----------
131-
algorithm : OnlineAlgorithm[Any, Any, Any]
132-
The algorithm to evaluate.
134+
entry : AlgorithmEntry
135+
The algorithm configuration entry to evaluate.
133136
threshold : float
134137
The detection threshold.
135138
providers : Sequence[ProviderT]
@@ -140,5 +143,5 @@ def _collect_runs(
140143
list[tuple[TraceT, ProviderT]]
141144
Batch of (trace, provider) pairs.
142145
"""
143-
runs = self._inner_runner._collect_runs(algorithm, threshold, providers)
146+
runs = self._inner_runner._collect_runs(entry, threshold, providers)
144147
return cast(list[tuple[TraceT, ProviderT]], runs)

pysatl_cpd/benchmark/core/benchmark_executor.py

Lines changed: 17 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,8 @@
2020
from pathlib import Path
2121
from typing import Any
2222

23+
from pysatl_cpd.core.algorithm_entry import AlgorithmEntry
2324
from pysatl_cpd.core.data_providers.idata_provider import DataProvider
24-
from pysatl_cpd.core.online.ionline_algorithm import OnlineAlgorithm
2525
from pysatl_cpd.core.online.online_cpd_solver import OnlineCpdSolver
2626
from pysatl_cpd.core.online.online_detection_trace import OnlineDetectionTrace
2727

@@ -37,7 +37,7 @@ class BenchmarkRecord:
3737
Parameters
3838
----------
3939
algorithm : str
40-
The string identifier or name of the online algorithm.
40+
The string identifier or name of the online algorithm (and transformer).
4141
configuration_hash : str
4242
A hash string representing the algorithm's configuration.
4343
data : str
@@ -79,9 +79,9 @@ class BenchmarkExecutor[DataT]:
7979
8080
Parameters
8181
----------
82-
algorithms : Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]]
83-
A sequence of tuples, where each tuple contains an instantiated online
84-
algorithm and a sequence of thresholds to test it against.
82+
entries : Sequence[AlgorithmEntry]
83+
A sequence of AlgorithmEntry objects, each grouping an algorithm,
84+
its thresholds, and an optional data transformer.
8585
providers : Sequence[DataProvider[DataT]]
8686
A sequence of data providers to be fed into the algorithms.
8787
solver : OnlineCpdSolver
@@ -94,12 +94,12 @@ class BenchmarkExecutor[DataT]:
9494

9595
def __init__(
9696
self,
97-
algorithms: Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]],
97+
entries: Sequence[AlgorithmEntry[Any, Any, Any]],
9898
providers: Sequence[DataProvider[DataT]],
9999
solver: OnlineCpdSolver,
100100
dump_dir: str | Path | None = None,
101101
) -> None:
102-
self.__algorithms = algorithms
102+
self.__entries = entries
103103
self.__providers = providers
104104
self.__solver = solver
105105
self.__dump_dir = Path(dump_dir) if dump_dir is not None else None
@@ -141,12 +141,17 @@ def execute(self) -> list[tuple[BenchmarkRecord, OnlineDetectionTrace[Any]]]:
141141
)
142142
registry[record.key] = record
143143

144-
for (algorithm, thresholds), provider in itertools.product(self.__algorithms, self.__providers):
145-
algo_name = str(algorithm)
146-
config_hash = hash(algorithm.configuration)
144+
for entry, provider in itertools.product(self.__entries, self.__providers):
145+
algo_name = entry.full_name
146+
config_hash = entry.full_hash
147147
data_name = provider.name
148148

149-
for threshold in thresholds:
149+
# Apply data transformer if specified in the entry
150+
active_provider = provider
151+
if entry.transformer is not None:
152+
active_provider = entry.transformer.transform(provider)
153+
154+
for threshold in entry.thresholds:
150155
key = (algo_name, config_hash, data_name, float(threshold))
151156

152157
if key in registry:
@@ -159,7 +164,7 @@ def execute(self) -> list[tuple[BenchmarkRecord, OnlineDetectionTrace[Any]]]:
159164
results.append((registry[key], trace))
160165
continue
161166

162-
steps = list(self.__solver.run(algorithm, provider, threshold))
167+
steps = list(self.__solver.run(entry.algorithm, active_provider, threshold))
163168
trace = OnlineDetectionTrace.from_run(steps, algo_name, config_hash)
164169

165170
record = BenchmarkRecord(algo_name, config_hash, data_name, threshold, None)

pysatl_cpd/benchmark/core/benchmark_logger.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
import logging
77
from typing import Any
88

9-
__author__ = "PySATL contributors"
9+
__author__ = "Danil Totmyanin"
1010
__copyright__ = "Copyright (c) 2026 PySATL project"
1111
__license__ = "SPDX-License-Identifier: MIT"
1212

pysatl_cpd/benchmark/noreset/noreset_benchmark_runner.py

Lines changed: 21 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
__copyright__ = "Copyright (c) 2026 PySATL project"
1414
__license__ = "SPDX-License-Identifier: MIT"
1515

16+
import dataclasses
1617
from collections.abc import Sequence
1718
from pathlib import Path
1819
from typing import Any
@@ -23,7 +24,7 @@
2324
from pysatl_cpd.benchmark.noreset.noreset_detection_trace import NoResetDetectionTrace
2425
from pysatl_cpd.benchmark.noreset.threshold_policy import ThresholdPolicy
2526
from pysatl_cpd.benchmark.online_benchmark_runner import OnlineBenchmarkRunner
26-
from pysatl_cpd.core.online.ionline_algorithm import OnlineAlgorithm
27+
from pysatl_cpd.core.algorithm_entry import AlgorithmEntry
2728
from pysatl_cpd.core.online.online_cpd_solver import OnlineCpdSolver
2829
from pysatl_cpd.core.online.online_detection_trace import OnlineDetectionTrace
2930

@@ -32,16 +33,17 @@ class NoResetBenchmarkRunner[ProviderT: LabeledData[Any]](OnlineBenchmarkRunner[
3233
"""
3334
Optimised benchmark runner for series with a single change point.
3435
35-
For each (algorithm, provider) pair the solver is executed exactly
36+
For each (algorithm entry, provider) pair the solver is executed exactly
3637
once with threshold=inf, producing a full detection function trace.
3738
All threshold evaluations are then simulated by applying a
3839
ThresholdPolicy to that cached trace, avoiding redundant solver runs.
3940
Caching is handled entirely by BenchmarkExecutor.
4041
4142
Parameters
4243
----------
43-
algorithms : Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]]
44-
Sequence of (algorithm, thresholds) pairs to evaluate.
44+
entries : Sequence[AlgorithmEntry]
45+
Sequence of AlgorithmEntry objects, each containing an algorithm, its thresholds,
46+
and an optional data transformer.
4547
providers : Sequence[ProviderT]
4648
Labeled data providers to run against.
4749
metrics : dict[str, MultipleRunMetric[NoResetDetectionTrace[Any], ProviderT, Any]]
@@ -54,11 +56,13 @@ class NoResetBenchmarkRunner[ProviderT: LabeledData[Any]](OnlineBenchmarkRunner[
5456
dump_dir : Path | str | None, optional
5557
Directory for caching inf traces via BenchmarkExecutor.
5658
If None, caching is disabled. Default is None.
59+
verbose : bool, default=False
60+
If True, displays progress bars during execution.
5761
"""
5862

5963
def __init__(
6064
self,
61-
algorithms: Sequence[tuple[OnlineAlgorithm[Any, Any, Any], Sequence[float]]],
65+
entries: Sequence[AlgorithmEntry[Any, Any, Any]],
6266
providers: Sequence[ProviderT],
6367
metrics: dict[str, MultipleRunMetric[NoResetDetectionTrace[Any], ProviderT, Any]],
6468
solver: OnlineCpdSolver,
@@ -67,7 +71,7 @@ def __init__(
6771
verbose: bool = False,
6872
) -> None:
6973
super().__init__(
70-
algorithms=algorithms,
74+
entries=entries,
7175
providers=providers,
7276
metrics=metrics,
7377
solver=solver,
@@ -76,8 +80,11 @@ def __init__(
7680
)
7781
self._policy = policy
7882

83+
# Replace all thresholds with inf for initial pre-caching run
84+
inf_entries = [dataclasses.replace(entry, thresholds=[float("inf")]) for entry in entries]
85+
7986
executor: BenchmarkExecutor[Any] = BenchmarkExecutor(
80-
algorithms=[(algorithm, [float("inf")]) for algorithm, _ in algorithms],
87+
entries=inf_entries,
8188
providers=list(providers),
8289
solver=self._solver,
8390
dump_dir=self._dump_dir,
@@ -86,26 +93,27 @@ def __init__(
8693
self._inf_trace_cache: dict[tuple[str, int, str], OnlineDetectionTrace[Any]] = {}
8794

8895
for record, trace in executor.execute():
96+
# record.algorithm holds entry.full_name; record.configuration_hash holds entry.full_hash
8997
key = (record.algorithm, record.configuration_hash, record.data)
9098
self._inf_trace_cache[key] = trace
9199

92100
def _collect_runs(
93101
self,
94-
algorithm: OnlineAlgorithm[Any, Any, Any],
102+
entry: AlgorithmEntry[Any, Any, Any],
95103
threshold: float,
96104
providers: Sequence[ProviderT],
97105
) -> list[tuple[NoResetDetectionTrace[Any], ProviderT]]:
98106
"""
99-
Collect NoReset runs for a given algorithm and threshold.
107+
Collect NoReset runs for a given algorithm entry and threshold.
100108
101109
For each provider, retrieves the inf trace via BenchmarkExecutor
102110
and applies the ThresholdPolicy to produce a lightweight
103111
NoResetDetectionTrace.
104112
105113
Parameters
106114
----------
107-
algorithm : OnlineAlgorithm[Any, Any, Any]
108-
The algorithm to evaluate.
115+
entry : AlgorithmEntry
116+
The algorithm configuration entry to evaluate.
109117
threshold : float
110118
The detection threshold to simulate.
111119
providers : Sequence[ProviderT]
@@ -119,8 +127,8 @@ def _collect_runs(
119127
if not providers:
120128
return []
121129

122-
algo_name = str(algorithm)
123-
config_hash = hash(algorithm.configuration)
130+
algo_name = entry.full_name
131+
config_hash = entry.full_hash
124132
runs: list[tuple[NoResetDetectionTrace[Any], ProviderT]] = []
125133

126134
for provider in providers:

0 commit comments

Comments
 (0)