Skip to content

Commit b377e6d

Browse files
committed
fix: move BenchmarkExecutor inside constructor in NoResetBenchmarkRunner
1 parent c6f9fd0 commit b377e6d

3 files changed

Lines changed: 259 additions & 47 deletions

File tree

examples/noreset_shewhart.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
"""
2+
Example: Shewhart Control Chart benchmark on Normal Distribution data
3+
using NoResetBenchmarkRunner with ClassificationReport metric.
4+
5+
Dataset structure:
6+
- n rows (labeled data providers)
7+
- Each row contains one change point
8+
- Before change point: N(0, 1)
9+
- After change point: N(mu_shift, 1)
10+
"""
11+
12+
import numpy as np
13+
14+
from pysatl_cpd.analysis.labeled_data import LabeledData
15+
from pysatl_cpd.benchmark.metrics.classification.classification_report import ClassificationReport
16+
from pysatl_cpd.benchmark.noreset.noreset_benchmark_runner import NoResetBenchmarkRunner
17+
from pysatl_cpd.benchmark.noreset.threshold_policy import EventBasedPolicy, PointBasedPolicy
18+
from pysatl_cpd.core.online.online_cpd_solver import OnlineCpdSolver
19+
from pysatl_cpd.algorithms.online.shewhart_control_chart import ShewhartControlChart
20+
21+
22+
# ---------------------------------------------------------------------------
23+
# 1. Labeled data provider
24+
# ---------------------------------------------------------------------------
25+
26+
class NormalShiftProvider(LabeledData[float]):
    """
    In-memory labeled series that carries exactly one annotated change point.

    The provider only stores what it is given: a pre-generated sequence of
    observations and the position of its single change point. In this
    example the observations are drawn from N(mu_before, sigma) before the
    change point and from N(mu_after, sigma) after it.

    Parameters
    ----------
    name : str
        Unique identifier for this provider.
    data : list[float]
        Pre-generated time series. The list is stored as-is (not copied).
    change_point : int
        1-based index of the true change point.
    """

    def __init__(self, name: str, data: list[float], change_point: int) -> None:
        self._name, self._data, self._change_point = name, data, change_point

    def __len__(self) -> int:
        """Return the number of observations in the series."""
        return len(self._data)

    def __iter__(self):
        """Return a fresh iterator over the stored observations."""
        return iter(self._data)

    @property
    def change_points(self) -> list[int]:
        """Annotated change points, as a new one-element list on every access."""
        return [self._change_point]

    @property
    def name(self) -> str:
        """Identifier supplied at construction time."""
        return self._name
61+
62+
63+
# ---------------------------------------------------------------------------
64+
# 2. Dataset generation
65+
# ---------------------------------------------------------------------------
66+
67+
def generate_dataset(
    n: int,
    series_length: int = 200,
    change_point: int = 100,
    mu_before: float = 0.0,
    mu_after: float = 3.0,
    sigma: float = 1.0,
    seed: int = 42,
) -> list[NormalShiftProvider]:
    """
    Generate n time series, each with one change point.

    Every series consists of ``change_point - 1`` samples drawn from
    N(mu_before, sigma) followed by samples drawn from N(mu_after, sigma)
    so that the total length equals ``series_length``. All series share a
    single random generator, so the whole dataset is reproducible from
    ``seed``.

    Parameters
    ----------
    n : int
        Number of series (rows).
    series_length : int
        Total length of each series.
    change_point : int
        1-based index where the mean shifts. Must satisfy
        ``1 <= change_point <= series_length + 1``.
    mu_before : float
        Mean before the change point.
    mu_after : float
        Mean after the change point.
    sigma : float
        Standard deviation (constant throughout).
    seed : int
        Random seed for reproducibility.

    Returns
    -------
    list[NormalShiftProvider]
        List of n labeled data providers named ``series_0000``,
        ``series_0001``, ...

    Raises
    ------
    ValueError
        If ``change_point`` does not fit inside a series of length
        ``series_length`` (either segment would have a negative size).
    """
    # 1-based labelling: samples 1..change_point-1 come before the shift,
    # samples change_point..series_length come after it. Both sizes are
    # loop-invariant, so compute and validate them once up front.
    n_before = change_point - 1
    n_after = series_length - n_before
    if n_before < 0 or n_after < 0:
        raise ValueError(
            f"change_point must satisfy 1 <= change_point <= series_length + 1, "
            f"got change_point={change_point}, series_length={series_length}"
        )

    rng = np.random.default_rng(seed)
    providers: list[NormalShiftProvider] = []

    for i in range(n):
        # Draw order (before, then after, per series) is kept fixed so that a
        # given seed always yields the same dataset.
        before = rng.normal(mu_before, sigma, size=n_before).tolist()
        after = rng.normal(mu_after, sigma, size=n_after).tolist()

        providers.append(
            NormalShiftProvider(
                name=f"series_{i:04d}",
                data=before + after,
                change_point=change_point,
            )
        )

    return providers
121+
122+
# ---------------------------------------------------------------------------
123+
# 3. Main benchmark
124+
# ---------------------------------------------------------------------------
125+
126+
def main() -> None:
    """
    Run the Shewhart control chart benchmark and print a metrics table.

    Generates a synthetic dataset of normally distributed series with a
    single mean shift, evaluates ``ShewhartControlChart`` over a grid of
    thresholds with ``NoResetBenchmarkRunner``, and prints TP/FP/FN,
    precision, recall and F1 for every threshold.
    """
    # --- Parameters ---
    N_SERIES = 25            # number of rows
    SERIES_LENGTH = 10100    # length of each series
    CHANGE_POINT = 10000     # 1-based change point position
    MU_BEFORE = 0.0
    MU_AFTER = 0.5           # mean after the change point
    SIGMA = 1.0

    # Shewhart parameters
    LEARNING_PERIOD = 1000
    WINDOW_SIZE = 50

    # Thresholds to evaluate
    THRESHOLDS = np.linspace(0, 7, 30)

    # Error margin for TP/FP/FN matching.
    # NOTE(review): presumably (left, right) tolerance in samples, i.e. a
    # detection counts if it is at most 0 samples before and at most 100
    # samples after the true change point -- confirm against
    # ClassificationReport.
    ERROR_MARGIN = (0, 100)

    # --- Generate dataset ---
    providers = generate_dataset(
        n=N_SERIES,
        series_length=SERIES_LENGTH,
        change_point=CHANGE_POINT,
        mu_before=MU_BEFORE,
        mu_after=MU_AFTER,
        sigma=SIGMA,
        seed=42,
    )

    # The shift is reported in units of sigma, so normalise by SIGMA.
    print(f"Dataset: {N_SERIES} series, length={SERIES_LENGTH}, "
          f"change_point={CHANGE_POINT}, shift={(MU_AFTER - MU_BEFORE) / SIGMA:.1f}σ")
    print("Algorithm: ShewhartControlChart("
          f"learning_period={LEARNING_PERIOD}, window={WINDOW_SIZE})")
    print(f"Thresholds: {THRESHOLDS}")
    print(f"Error margin: {ERROR_MARGIN}")
    print("-" * 60)

    # --- Algorithm ---
    algorithm = ShewhartControlChart(
        learning_period_size=LEARNING_PERIOD,
        window_size=WINDOW_SIZE,
    )

    # --- Metrics ---
    metrics = {
        "classification_report": ClassificationReport(error_margin=ERROR_MARGIN),
    }

    # --- Policy ---
    # The policy receives the right-hand side of the error margin.
    policy = EventBasedPolicy(ERROR_MARGIN[1], strict_edge=False)

    # --- Solver ---
    solver = OnlineCpdSolver()

    # --- Runner ---
    runner = NoResetBenchmarkRunner(
        algorithms=[(algorithm, THRESHOLDS)],
        providers=providers,
        metrics=metrics,
        solver=solver,
        policy=policy,
        # A dump directory is set, so traces are expected to be cached on disk
        # under this path.
        dump_dir="benchmark_cache/",
    )

    # --- Run ---
    results = runner.run()

    # --- Print results ---
    print(f"\n{'Threshold':>10} | {'TP':>6} | {'FP':>6} | {'FN':>6} | "
          f"{'Precision':>10} | {'Recall':>10} | {'F1':>10}")
    print("-" * 70)

    # Result keys are (algorithm name, configuration) pairs; only one algorithm
    # is benchmarked here, so the key is not printed.
    for (_algo_name, _config), threshold_results in results.items():
        for threshold, metric_values in threshold_results:
            report = metric_values["classification_report"]
            print(
                f"{threshold:>10.1f} | "
                f"{report['tp']:>6.0f} | "
                f"{report['fp']:>6.0f} | "
                f"{report['fn']:>6.0f} | "
                f"{report['precision']:>10.4f} | "
                f"{report['recall']:>10.4f} | "
                f"{report['f1']:>10.4f}"
            )
211+
212+
213+
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

pysatl_cpd/benchmark/noreset/noreset_benchmark_runner.py

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -74,37 +74,18 @@ def __init__(
7474
)
7575
self._policy = policy
7676

77-
def _get_inf_trace(
78-
self,
79-
algorithm: OnlineAlgorithm[Any, Any, Any],
80-
provider: ProviderT,
81-
) -> OnlineDetectionTrace[Any]:
82-
"""
83-
Compute or retrieve the infinite-threshold trace for a given pair.
84-
85-
Delegates entirely to BenchmarkExecutor which handles disk caching
86-
when dump_dir is set.
87-
88-
Parameters
89-
----------
90-
algorithm : OnlineAlgorithm[Any, Any, Any]
91-
The algorithm to run.
92-
provider : ProviderT
93-
The data provider to run against.
94-
95-
Returns
96-
-------
97-
OnlineDetectionTrace[Any]
98-
Trace produced with threshold=inf.
99-
"""
10077
executor: BenchmarkExecutor[Any] = BenchmarkExecutor(
101-
algorithms=[(algorithm, [float("inf")])],
102-
providers=[provider],
78+
algorithms=[(algorithm, [float("inf")]) for algorithm, _ in algorithms],
79+
providers=list(providers),
10380
solver=self._solver,
10481
dump_dir=self._dump_dir,
10582
)
106-
_, inf_trace = executor.execute()[0]
107-
return inf_trace
83+
84+
self._inf_trace_cache: dict[tuple[str, int, str], OnlineDetectionTrace[Any]] = {}
85+
86+
for record, trace in executor.execute():
87+
key = (record.algorithm, record.configuration_hash, record.data)
88+
self._inf_trace_cache[key] = trace
10889

10990
def _collect_runs(
11091
self,
@@ -136,10 +117,13 @@ def _collect_runs(
136117
if not providers:
137118
return []
138119

120+
algo_name = str(algorithm)
121+
config_hash = hash(algorithm.configuration)
139122
runs: list[tuple[NoResetDetectionTrace[Any], ProviderT]] = []
140123

141124
for provider in providers:
142-
inf_trace = self._get_inf_trace(algorithm, provider)
125+
cache_key = (algo_name, config_hash, provider.name)
126+
inf_trace = self._inf_trace_cache[cache_key]
143127

144128
detected_change_points: list[int] = self._policy.apply(
145129
inf_trace.detection_function,

0 commit comments

Comments
 (0)