"""
Example: Shewhart Control Chart benchmark on Normal Distribution data
using NoResetBenchmarkRunner with ClassificationReport & Delay metrics,
and ARLBenchmarkRunner for Average Run Length evaluation.

Dataset structure (change-point benchmark):
- n rows (labeled data providers)
- Each row contains one change point
- Before change point: N(0, 1)
- After change point: N(mu_shift, 1)

The ARL benchmark uses a second dataset of stationary series without
change points.
"""
116
127import numpy as np
138
9+ from pysatl_cpd .algorithms .online .shewhart_control_chart import ShewhartControlChart
1410from pysatl_cpd .analysis .labeled_data import LabeledData
11+ from pysatl_cpd .benchmark .arl_benchmark_runner import ARLBenchmarkRunner
1512from pysatl_cpd .benchmark .metrics .classification .classification_report import ClassificationReport
13+ from pysatl_cpd .benchmark .metrics .online .delay_metric import MeanDelayMetric , MedianDelayMetric
1614from pysatl_cpd .benchmark .noreset .noreset_benchmark_runner import NoResetBenchmarkRunner
17- from pysatl_cpd .benchmark .noreset .threshold_policy import EventBasedPolicy , PointBasedPolicy
15+ from pysatl_cpd .benchmark .noreset .threshold_policy import EventBasedPolicy
1816from pysatl_cpd .core .online .online_cpd_solver import OnlineCpdSolver
19- from pysatl_cpd .algorithms .online .shewhart_control_chart import ShewhartControlChart
20-
2117
2218# ---------------------------------------------------------------------------
23- # 1. Labeled data provider
19+ # 1. Labeled data providers
2420# ---------------------------------------------------------------------------
2521
22+
2623class NormalShiftProvider (LabeledData [float ]):
27- """
28- Labeled data provider for a single time series with one change point.
29-
30- Before change point: N(mu_before, sigma)
31- After change point: N(mu_after, sigma)
32-
33- Parameters
34- ----------
35- name : str
36- Unique identifier for this provider.
37- data : list[float]
38- Pre-generated time series.
39- change_point : int
40- 1-based index of the true change point.
41- """
24+ """Provider for a single time series WITH one change point."""
4225
4326 def __init__ (self , name : str , data : list [float ], change_point : int ) -> None :
4427 self ._name = name
@@ -60,10 +43,33 @@ def __len__(self) -> int:
6043 return len (self ._data )
6144
6245
class NormalNullProvider(LabeledData[float]):
    """
    Labeled data provider for a single time series WITHOUT change points.

    Used as input for Average Run Length (ARL) evaluation.

    Parameters
    ----------
    name : str
        Identifier returned by the ``name`` property.
    data : list[float]
        Pre-generated time series. Stored by reference (not copied).
    """

    def __init__(self, name: str, data: list[float]) -> None:
        self._name = name
        self._data = data

    @property
    def name(self) -> str:
        """Return the identifier passed at construction."""
        return self._name

    @property
    def change_points(self) -> list[int]:
        """Return an empty list: this series is labeled as having no change points."""
        # A fresh list on every access, so callers cannot mutate shared state.
        return []

    def __iter__(self):
        """Iterate over the stored observations in order."""
        return iter(self._data)

    def __len__(self) -> int:
        """Return the number of stored observations."""
        return len(self._data)
66+
67+
6368# ---------------------------------------------------------------------------
6469# 2. Dataset generation
6570# ---------------------------------------------------------------------------
6671
72+
6773def generate_dataset (
6874 n : int ,
6975 series_length : int = 200 ,
@@ -73,63 +79,61 @@ def generate_dataset(
7379 sigma : float = 1.0 ,
7480 seed : int = 42 ,
7581) -> list [NormalShiftProvider ]:
76- """
77- Generate n time series, each with one change point.
78-
79- Parameters
80- ----------
81- n : int
82- Number of series (rows).
83- series_length : int
84- Total length of each series.
85- change_point : int
86- 1-based index where the mean shifts.
87- mu_before : float
88- Mean before the change point.
89- mu_after : float
90- Mean after the change point.
91- sigma : float
92- Standard deviation (constant throughout).
93- seed : int
94- Random seed for reproducibility.
95-
96- Returns
97- -------
98- list[NormalShiftProvider]
99- List of n labeled data providers.
100- """
82+ """Generate n time series, each with one change point."""
10183 rng = np .random .default_rng (seed )
10284 providers = []
10385
10486 for i in range (n ):
105- # Segment before change point (1-based: indices 1..change_point-1)
10687 n_before = change_point - 1
10788 n_after = series_length - n_before
10889
10990 before = rng .normal (mu_before , sigma , size = n_before ).tolist ()
11091 after = rng .normal (mu_after , sigma , size = n_after ).tolist ()
11192
112- data = before + after
11393 provider = NormalShiftProvider (
11494 name = f"series_{ i :04d} " ,
115- data = data ,
95+ data = before + after ,
11696 change_point = change_point ,
11797 )
11898 providers .append (provider )
11999
120100 return providers
121101
102+
def generate_arl_dataset(
    n: int,
    series_length: int = 200,
    mu: float = 0.0,
    sigma: float = 1.0,
    seed: int = 42,
) -> list[NormalNullProvider]:
    """
    Generate n stationary time series without change points for ARL.

    Parameters
    ----------
    n : int
        Number of series to generate.
    series_length : int
        Number of observations in each series.
    mu : float
        Mean of the normal distribution the observations are drawn from.
    sigma : float
        Standard deviation of that distribution.
    seed : int
        Seed for the NumPy random generator, for reproducibility.

    Returns
    -------
    list[NormalNullProvider]
        n providers named ``arl_series_0000``, ``arl_series_0001``, ...
    """
    rng = np.random.default_rng(seed)

    # Draw each series with its own call, in index order, so the values
    # produced for a given seed do not depend on how the list is built.
    return [
        NormalNullProvider(
            name=f"arl_series_{i:04d}",
            data=rng.normal(mu, sigma, size=series_length).tolist(),
        )
        for i in range(n)
    ]
123+
124+
122125# ---------------------------------------------------------------------------
123- # 4 . Main benchmark
126+ # 3 . Main benchmark
124127# ---------------------------------------------------------------------------
125128
129+
126130def main () -> None :
127131 # --- Parameters ---
128- N_SERIES = 25 # number of rows
129- SERIES_LENGTH = 10100 # length of each series
130- CHANGE_POINT = 10000 # 1-based change point position
132+ N_SERIES = 25
133+ SERIES_LENGTH = 10100
134+ CHANGE_POINT = 10000
131135 MU_BEFORE = 0.0
132- MU_AFTER = 0.5 # mean shift magnitude
136+ MU_AFTER = 0.5
133137 SIGMA = 1.0
134138
135139 # Shewhart parameters
@@ -139,10 +143,11 @@ def main() -> None:
139143 # Thresholds to evaluate
140144 THRESHOLDS = np .linspace (0 , 7 , 30 )
141145
142- # Error margin for TP/FP/FN matching
143- ERROR_MARGIN = (0 , 100 ) # +/- 5 samples around true change point
146+ # Error margin for TP/FP/FN matching & Delays
147+ ERROR_MARGIN = (0 , 100 )
144148
145- # --- Generate dataset ---
149+ # --- Generate datasets ---
150+ # 1. Dataset with change points for Quality and Delays
146151 providers = generate_dataset (
147152 n = N_SERIES ,
148153 series_length = SERIES_LENGTH ,
@@ -152,63 +157,113 @@ def main() -> None:
152157 sigma = SIGMA ,
153158 seed = 42 ,
154159 )
160+ # 2. Dataset without change points for ARL
161+ arl_providers = generate_arl_dataset (
162+ n = N_SERIES ,
163+ series_length = SERIES_LENGTH ,
164+ mu = MU_BEFORE ,
165+ sigma = SIGMA ,
166+ seed = 42 ,
167+ )
155168
156- print (f"Dataset: { N_SERIES } series, length= { SERIES_LENGTH } , "
157- f"change_point= { CHANGE_POINT } , shift= { MU_AFTER - MU_BEFORE :.1f } σ" )
158- print ( f"Algorithm: ShewhartControlChart( "
159- f"learning_period= { LEARNING_PERIOD } , window= { WINDOW_SIZE } )" )
160- print (f"Thresholds: { THRESHOLDS } " )
169+ print (f"Algorithm: ShewhartControlChart(learning_period= { LEARNING_PERIOD } , window= { WINDOW_SIZE } )" )
170+ print (
171+ f"Dataset (NoReset): { N_SERIES } series, length= { SERIES_LENGTH } , change_point= { CHANGE_POINT } , shift= { MU_AFTER - MU_BEFORE :.1f } σ "
172+ )
173+ print (f"Dataset (ARL): { N_SERIES } series, length= { SERIES_LENGTH } , no change points " )
161174 print (f"Error margin: { ERROR_MARGIN } " )
162- print ("-" * 60 )
175+ print ("-" * 115 )
163176
164- # --- Algorithm ---
165177 algorithm = ShewhartControlChart (
166178 learning_period_size = LEARNING_PERIOD ,
167179 window_size = WINDOW_SIZE ,
168180 )
181+ solver = OnlineCpdSolver ()
169182
170- # --- Metrics ---
183+ # ==========================================
184+ # RUN 1: Classification & Delays (NoReset)
185+ # ==========================================
171186 metrics = {
172187 "classification_report" : ClassificationReport (error_margin = ERROR_MARGIN ),
188+ "mean_delay" : MeanDelayMetric (max_delay = ERROR_MARGIN [1 ]),
189+ "median_delay" : MedianDelayMetric (max_delay = ERROR_MARGIN [1 ]),
173190 }
174-
175- # --- Policy ---
176191 policy = EventBasedPolicy (ERROR_MARGIN [1 ], strict_edge = False )
177192
178- # --- Solver ---
179- solver = OnlineCpdSolver ()
180-
181- # --- Runner ---
182193 runner = NoResetBenchmarkRunner (
183194 algorithms = [(algorithm , THRESHOLDS )],
184195 providers = providers ,
185196 metrics = metrics ,
186197 solver = solver ,
187198 policy = policy ,
188- dump_dir = "benchmark_cache/" , # no caching
199+ dump_dir = "benchmark_cache/noreset" ,
189200 )
201+ noreset_results = runner .run ()
190202
191- # --- Run ---
192- results = runner .run ()
203+ # ==========================================
204+ # RUN 2: Average Run Length (ARL)
205+ # ==========================================
206+ arl_runner = ARLBenchmarkRunner (
207+ algorithms = [(algorithm , THRESHOLDS )],
208+ providers = arl_providers ,
209+ solver = solver ,
210+ mode = "noreset" , # uses rapid point-based extraction behind the scenes
211+ dump_dir = "benchmark_cache/arl" ,
212+ )
213+ arl_results = arl_runner .run ()
214+
215+ # ==========================================
216+ # Combine and Print Results
217+ # ==========================================
193218
194- # --- Print results ---
195- print (f"\n { 'Threshold' :>10} | { 'TP' :>6} | { 'FP' :>6} | { 'FN' :>6} | "
196- f"{ 'Precision' :>10} | { 'Recall' :>10} | { 'F1' :>10} " )
197- print ("-" * 70 )
219+ # Structure to hold merged metrics: {threshold: {"metric_name": value}}
220+ combined_results = {}
221+
222+ # 1. Parse ARL
223+ for (_algo_name , _config ), threshold_results in arl_results .items ():
224+ for threshold , metric_values in threshold_results :
225+ combined_results .setdefault (threshold , {})["arl" ] = metric_values ["arl" ]
198226
199- for (algo_name , config ), threshold_results in results .items ():
227+ # 2. Parse Quality & Delays
228+ for (_algo_name , _config ), threshold_results in noreset_results .items ():
200229 for threshold , metric_values in threshold_results :
201- report = metric_values ["classification_report" ]
202- print (
203- f"{ threshold :>10.1f} | "
204- f"{ report ['tp' ]:>6.0f} | "
205- f"{ report ['fp' ]:>6.0f} | "
206- f"{ report ['fn' ]:>6.0f} | "
207- f"{ report ['precision' ]:>10.4f} | "
208- f"{ report ['recall' ]:>10.4f} | "
209- f"{ report ['f1' ]:>10.4f} "
230+ rep = metric_values ["classification_report" ]
231+ combined_results .setdefault (threshold , {}).update (
232+ {
233+ "tp" : rep ["tp" ],
234+ "fp" : rep ["fp" ],
235+ "fn" : rep ["fn" ],
236+ "precision" : rep ["precision" ],
237+ "recall" : rep ["recall" ],
238+ "f1" : rep ["f1" ],
239+ "mean_delay" : metric_values ["mean_delay" ],
240+ "median_delay" : metric_values ["median_delay" ],
241+ }
210242 )
211243
244+ # 3. Print unified table
245+ print (
246+ f"\n { 'Threshold' :>10} | { 'ARL' :>10} | { 'TP' :>4} | { 'FP' :>4} | { 'FN' :>4} | "
247+ f"{ 'Precision' :>9} | { 'Recall' :>9} | { 'F1' :>9} | "
248+ f"{ 'Mean Delay' :>8} | { 'Med Delay' :>8} "
249+ )
250+ print ("-" * 115 )
251+
252+ for threshold in sorted (combined_results .keys ()):
253+ res = combined_results [threshold ]
254+ print (
255+ f"{ threshold :>10.1f} | "
256+ f"{ res .get ('arl' , float ('inf' )):>10.1f} | "
257+ f"{ res .get ('tp' , 0 ):>4.0f} | "
258+ f"{ res .get ('fp' , 0 ):>4.0f} | "
259+ f"{ res .get ('fn' , 0 ):>4.0f} | "
260+ f"{ res .get ('precision' , 0 ):>9.4f} | "
261+ f"{ res .get ('recall' , 0 ):>9.4f} | "
262+ f"{ res .get ('f1' , 0 ):>9.4f} | "
263+ f"{ res .get ('mean_delay' , 0 ):>8.1f} | "
264+ f"{ res .get ('median_delay' , 0 ):>8.1f} "
265+ )
266+
212267
213268if __name__ == "__main__" :
214269 main ()
0 commit comments