Skip to content

Commit 06a746f

Browse files
committed
feat: add ARL calculation to example
1 parent b377e6d commit 06a746f

2 files changed

Lines changed: 150 additions & 95 deletions

File tree

examples/noreset_shewhart.py

Lines changed: 148 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,27 @@
11
"""
22
Example: Shewhart Control Chart benchmark on Normal Distribution data
3-
using NoResetBenchmarkRunner with ClassificationReport metric.
4-
5-
Dataset structure:
6-
- n rows (labeled data providers)
7-
- Each row contains one change point
8-
- Before change point: N(0, 1)
9-
- After change point: N(mu_shift, 1)
3+
using NoResetBenchmarkRunner with ClassificationReport & Delay metrics,
4+
and ARLBenchmarkRunner for Average Run Length evaluation.
105
"""
116

127
import numpy as np
138

9+
from pysatl_cpd.algorithms.online.shewhart_control_chart import ShewhartControlChart
1410
from pysatl_cpd.analysis.labeled_data import LabeledData
11+
from pysatl_cpd.benchmark.arl_benchmark_runner import ARLBenchmarkRunner
1512
from pysatl_cpd.benchmark.metrics.classification.classification_report import ClassificationReport
13+
from pysatl_cpd.benchmark.metrics.online.delay_metric import MeanDelayMetric, MedianDelayMetric
1614
from pysatl_cpd.benchmark.noreset.noreset_benchmark_runner import NoResetBenchmarkRunner
17-
from pysatl_cpd.benchmark.noreset.threshold_policy import EventBasedPolicy, PointBasedPolicy
15+
from pysatl_cpd.benchmark.noreset.threshold_policy import EventBasedPolicy
1816
from pysatl_cpd.core.online.online_cpd_solver import OnlineCpdSolver
19-
from pysatl_cpd.algorithms.online.shewhart_control_chart import ShewhartControlChart
20-
2117

2218
# ---------------------------------------------------------------------------
23-
# 1. Labeled data provider
19+
# 1. Labeled data providers
2420
# ---------------------------------------------------------------------------
2521

22+
2623
class NormalShiftProvider(LabeledData[float]):
27-
"""
28-
Labeled data provider for a single time series with one change point.
29-
30-
Before change point: N(mu_before, sigma)
31-
After change point: N(mu_after, sigma)
32-
33-
Parameters
34-
----------
35-
name : str
36-
Unique identifier for this provider.
37-
data : list[float]
38-
Pre-generated time series.
39-
change_point : int
40-
1-based index of the true change point.
41-
"""
24+
"""Provider for a single time series WITH one change point."""
4225

4326
def __init__(self, name: str, data: list[float], change_point: int) -> None:
4427
self._name = name
@@ -60,10 +43,33 @@ def __len__(self) -> int:
6043
return len(self._data)
6144

6245

46+
class NormalNullProvider(LabeledData[float]):
    """Labeled provider for one time series that contains no change points.

    Used as the input for ARL evaluation, where the series is expected to be
    stationary and the list of true change points is always empty.

    Parameters
    ----------
    name : str
        Identifier reported by the ``name`` property.
    data : list[float]
        Pre-generated samples of the series.
    """

    def __init__(self, name: str, data: list[float]) -> None:
        # Keep references only; the sample list is not copied.
        self._data = data
        self._name = name

    def __len__(self) -> int:
        """Return the number of samples in the series."""
        return len(self._data)

    def __iter__(self):
        """Iterate over the samples in their stored order."""
        return iter(self._data)

    @property
    def change_points(self) -> list[int]:
        """Return an empty list: this series has no change points."""
        return []

    @property
    def name(self) -> str:
        """Return the identifier given at construction."""
        return self._name
66+
67+
6368
# ---------------------------------------------------------------------------
6469
# 2. Dataset generation
6570
# ---------------------------------------------------------------------------
6671

72+
6773
def generate_dataset(
6874
n: int,
6975
series_length: int = 200,
@@ -73,63 +79,61 @@ def generate_dataset(
7379
sigma: float = 1.0,
7480
seed: int = 42,
7581
) -> list[NormalShiftProvider]:
76-
"""
77-
Generate n time series, each with one change point.
78-
79-
Parameters
80-
----------
81-
n : int
82-
Number of series (rows).
83-
series_length : int
84-
Total length of each series.
85-
change_point : int
86-
1-based index where the mean shifts.
87-
mu_before : float
88-
Mean before the change point.
89-
mu_after : float
90-
Mean after the change point.
91-
sigma : float
92-
Standard deviation (constant throughout).
93-
seed : int
94-
Random seed for reproducibility.
95-
96-
Returns
97-
-------
98-
list[NormalShiftProvider]
99-
List of n labeled data providers.
100-
"""
82+
"""Generate n time series, each with one change point."""
10183
rng = np.random.default_rng(seed)
10284
providers = []
10385

10486
for i in range(n):
105-
# Segment before change point (1-based: indices 1..change_point-1)
10687
n_before = change_point - 1
10788
n_after = series_length - n_before
10889

10990
before = rng.normal(mu_before, sigma, size=n_before).tolist()
11091
after = rng.normal(mu_after, sigma, size=n_after).tolist()
11192

112-
data = before + after
11393
provider = NormalShiftProvider(
11494
name=f"series_{i:04d}",
115-
data=data,
95+
data=before + after,
11696
change_point=change_point,
11797
)
11898
providers.append(provider)
11999

120100
return providers
121101

102+
103+
def generate_arl_dataset(
    n: int,
    series_length: int = 200,
    mu: float = 0.0,
    sigma: float = 1.0,
    seed: int = 42,
) -> list[NormalNullProvider]:
    """Build ``n`` change-free series of normal samples for ARL evaluation.

    Parameters
    ----------
    n : int
        Number of series to generate.
    series_length : int
        Number of samples in each series.
    mu : float
        Mean of the normal distribution the samples are drawn from.
    sigma : float
        Standard deviation of that distribution.
    seed : int
        Seed passed to ``np.random.default_rng`` for reproducibility.

    Returns
    -------
    list[NormalNullProvider]
        One provider per series, named ``arl_series_0000``,
        ``arl_series_0001``, ... in generation order.
    """
    generator = np.random.default_rng(seed)

    # All series are drawn sequentially from the same generator, so the whole
    # dataset is determined by ``seed`` alone.
    return [
        NormalNullProvider(
            name=f"arl_series_{index:04d}",
            data=generator.normal(mu, sigma, size=series_length).tolist(),
        )
        for index in range(n)
    ]
123+
124+
122125
# ---------------------------------------------------------------------------
123-
# 4. Main benchmark
126+
# 3. Main benchmark
124127
# ---------------------------------------------------------------------------
125128

129+
126130
def main() -> None:
127131
# --- Parameters ---
128-
N_SERIES = 25 # number of rows
129-
SERIES_LENGTH = 10100 # length of each series
130-
CHANGE_POINT = 10000 # 1-based change point position
132+
N_SERIES = 25
133+
SERIES_LENGTH = 10100
134+
CHANGE_POINT = 10000
131135
MU_BEFORE = 0.0
132-
MU_AFTER = 0.5 # mean shift magnitude
136+
MU_AFTER = 0.5
133137
SIGMA = 1.0
134138

135139
# Shewhart parameters
@@ -139,10 +143,11 @@ def main() -> None:
139143
# Thresholds to evaluate
140144
THRESHOLDS = np.linspace(0, 7, 30)
141145

142-
# Error margin for TP/FP/FN matching
143-
ERROR_MARGIN = (0, 100) # +/- 5 samples around true change point
146+
# Error margin for TP/FP/FN matching & Delays
147+
ERROR_MARGIN = (0, 100)
144148

145-
# --- Generate dataset ---
149+
# --- Generate datasets ---
150+
# 1. Dataset with change points for Quality and Delays
146151
providers = generate_dataset(
147152
n=N_SERIES,
148153
series_length=SERIES_LENGTH,
@@ -152,63 +157,113 @@ def main() -> None:
152157
sigma=SIGMA,
153158
seed=42,
154159
)
160+
# 2. Dataset without change points for ARL
161+
arl_providers = generate_arl_dataset(
162+
n=N_SERIES,
163+
series_length=SERIES_LENGTH,
164+
mu=MU_BEFORE,
165+
sigma=SIGMA,
166+
seed=42,
167+
)
155168

156-
print(f"Dataset: {N_SERIES} series, length={SERIES_LENGTH}, "
157-
f"change_point={CHANGE_POINT}, shift={MU_AFTER - MU_BEFORE:.1f}σ")
158-
print(f"Algorithm: ShewhartControlChart("
159-
f"learning_period={LEARNING_PERIOD}, window={WINDOW_SIZE})")
160-
print(f"Thresholds: {THRESHOLDS}")
169+
print(f"Algorithm: ShewhartControlChart(learning_period={LEARNING_PERIOD}, window={WINDOW_SIZE})")
170+
print(
171+
f"Dataset (NoReset): {N_SERIES} series, length={SERIES_LENGTH}, change_point={CHANGE_POINT}, shift={MU_AFTER - MU_BEFORE:.1f}σ"
172+
)
173+
print(f"Dataset (ARL): {N_SERIES} series, length={SERIES_LENGTH}, no change points")
161174
print(f"Error margin: {ERROR_MARGIN}")
162-
print("-" * 60)
175+
print("-" * 115)
163176

164-
# --- Algorithm ---
165177
algorithm = ShewhartControlChart(
166178
learning_period_size=LEARNING_PERIOD,
167179
window_size=WINDOW_SIZE,
168180
)
181+
solver = OnlineCpdSolver()
169182

170-
# --- Metrics ---
183+
# ==========================================
184+
# RUN 1: Classification & Delays (NoReset)
185+
# ==========================================
171186
metrics = {
172187
"classification_report": ClassificationReport(error_margin=ERROR_MARGIN),
188+
"mean_delay": MeanDelayMetric(max_delay=ERROR_MARGIN[1]),
189+
"median_delay": MedianDelayMetric(max_delay=ERROR_MARGIN[1]),
173190
}
174-
175-
# --- Policy ---
176191
policy = EventBasedPolicy(ERROR_MARGIN[1], strict_edge=False)
177192

178-
# --- Solver ---
179-
solver = OnlineCpdSolver()
180-
181-
# --- Runner ---
182193
runner = NoResetBenchmarkRunner(
183194
algorithms=[(algorithm, THRESHOLDS)],
184195
providers=providers,
185196
metrics=metrics,
186197
solver=solver,
187198
policy=policy,
188-
dump_dir="benchmark_cache/", # no caching
199+
dump_dir="benchmark_cache/noreset",
189200
)
201+
noreset_results = runner.run()
190202

191-
# --- Run ---
192-
results = runner.run()
203+
# ==========================================
204+
# RUN 2: Average Run Length (ARL)
205+
# ==========================================
206+
arl_runner = ARLBenchmarkRunner(
207+
algorithms=[(algorithm, THRESHOLDS)],
208+
providers=arl_providers,
209+
solver=solver,
210+
mode="noreset", # uses rapid point-based extraction behind the scenes
211+
dump_dir="benchmark_cache/arl",
212+
)
213+
arl_results = arl_runner.run()
214+
215+
# ==========================================
216+
# Combine and Print Results
217+
# ==========================================
193218

194-
# --- Print results ---
195-
print(f"\n{'Threshold':>10} | {'TP':>6} | {'FP':>6} | {'FN':>6} | "
196-
f"{'Precision':>10} | {'Recall':>10} | {'F1':>10}")
197-
print("-" * 70)
219+
# Structure to hold merged metrics: {threshold: {"metric_name": value}}
220+
combined_results = {}
221+
222+
# 1. Parse ARL
223+
for (_algo_name, _config), threshold_results in arl_results.items():
224+
for threshold, metric_values in threshold_results:
225+
combined_results.setdefault(threshold, {})["arl"] = metric_values["arl"]
198226

199-
for (algo_name, config), threshold_results in results.items():
227+
# 2. Parse Quality & Delays
228+
for (_algo_name, _config), threshold_results in noreset_results.items():
200229
for threshold, metric_values in threshold_results:
201-
report = metric_values["classification_report"]
202-
print(
203-
f"{threshold:>10.1f} | "
204-
f"{report['tp']:>6.0f} | "
205-
f"{report['fp']:>6.0f} | "
206-
f"{report['fn']:>6.0f} | "
207-
f"{report['precision']:>10.4f} | "
208-
f"{report['recall']:>10.4f} | "
209-
f"{report['f1']:>10.4f}"
230+
rep = metric_values["classification_report"]
231+
combined_results.setdefault(threshold, {}).update(
232+
{
233+
"tp": rep["tp"],
234+
"fp": rep["fp"],
235+
"fn": rep["fn"],
236+
"precision": rep["precision"],
237+
"recall": rep["recall"],
238+
"f1": rep["f1"],
239+
"mean_delay": metric_values["mean_delay"],
240+
"median_delay": metric_values["median_delay"],
241+
}
210242
)
211243

244+
# 3. Print unified table
245+
print(
246+
f"\n{'Threshold':>10} | {'ARL':>10} | {'TP':>4} | {'FP':>4} | {'FN':>4} | "
247+
f"{'Precision':>9} | {'Recall':>9} | {'F1':>9} | "
248+
f"{'Mean Delay':>8} | {'Med Delay':>8}"
249+
)
250+
print("-" * 115)
251+
252+
for threshold in sorted(combined_results.keys()):
253+
res = combined_results[threshold]
254+
print(
255+
f"{threshold:>10.1f} | "
256+
f"{res.get('arl', float('inf')):>10.1f} | "
257+
f"{res.get('tp', 0):>4.0f} | "
258+
f"{res.get('fp', 0):>4.0f} | "
259+
f"{res.get('fn', 0):>4.0f} | "
260+
f"{res.get('precision', 0):>9.4f} | "
261+
f"{res.get('recall', 0):>9.4f} | "
262+
f"{res.get('f1', 0):>9.4f} | "
263+
f"{res.get('mean_delay', 0):>8.1f} | "
264+
f"{res.get('median_delay', 0):>8.1f}"
265+
)
266+
212267

213268
if __name__ == "__main__":
214269
main()

pysatl_cpd/benchmark/arl_benchmark_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class ARLBenchmarkRunner[TraceT: OnlineDetectionTrace[Any], ProviderT: LabeledDa
5252
Solver used to run algorithms against providers.
5353
mode : Literal["reset", "noreset"]
5454
Evaluation mode determining whether the algorithm resets after a detection.
55-
dump_dir : Path | None, optional
55+
dump_dir : Path | str | None, optional
5656
Directory for caching results via BenchmarkExecutor.
5757
If None, caching is disabled. Default is None.
5858
@@ -70,7 +70,7 @@ def __init__(
7070
providers: list[ProviderT],
7171
solver: OnlineCpdSolver,
7272
mode: Literal["reset", "noreset"],
73-
dump_dir: Path | None = None,
73+
dump_dir: Path | str | None = None,
7474
) -> None:
7575
for provider in providers:
7676
if provider.change_points:

0 commit comments

Comments
 (0)