@@ -40,6 +40,10 @@ class OnlineCpdSolver[T, ConfugrationT: OnlineAlgorithmConfiguration, StateT: On
4040 If not ``None``, forces a change-point declaration once the run
4141 length exceeds this value. Must be positive if specified.
4242 Default is ``None``.
43+ collect_states : bool, optional
44+ Whether to collect algorithm state snapshots in step results.
45+ If False, algorithm_state will be None in all step results.
46+ Default is True.
4347
4448 Raises
4549 ------
@@ -56,6 +60,7 @@ def __init__(
5660 threshold : float = float ("nan" ),
5761 skip_period : int = 0 ,
5862 max_runlength : int | None = None ,
63+ collect_states : bool = True ,
5964 ) -> None :
6065 """
6166 Initialize the online change-point detection solver.
@@ -72,6 +77,8 @@ def __init__(
7277 Number of steps to skip after each declared change point.
7378 max_runlength : int or None, optional
7479 Maximum run length before forcing a change point.
80+ collect_states : bool, optional
81+ Whether to collect algorithm state snapshots.
7582 """
7683 # Validate skip_period is non-negative
7784 if skip_period < 0 :
@@ -86,6 +93,7 @@ def __init__(
8693 self .__threshold = threshold
8794 self .__skip_period = skip_period
8895 self .__max_runlength = max_runlength
96+ self .__collect_states = collect_states
8997
9098 self .__in_skip_period = False
9199
@@ -119,11 +127,15 @@ def run(self) -> Iterator[OnlineDetectionStepResult[StateT]]:
119127 if self .__in_skip_period :
120128 if skip_period_counter < self .__skip_period :
121129 skip_period_counter += 1
122- yield OnlineDetectionStepResult (step_num = step , is_in_skip_period = True )
130+ yield OnlineDetectionStepResult (
131+ step_num = step ,
132+ is_in_skip_period = True ,
133+ algorithm_state = self .__algorithm .state if self .__collect_states else None ,
134+ )
135+ continue
123136 if skip_period_counter == self .__skip_period :
124137 self .__in_skip_period = False
125138 skip_period_counter = 0
126- continue
127139
128140 # Process observation normally
129141 step_start_time : float = time .perf_counter ()
@@ -136,14 +148,17 @@ def run(self) -> Iterator[OnlineDetectionStepResult[StateT]]:
136148 is_change_point : bool = self ._is_change_point (detection_func , run_length )
137149 is_forced : bool = self ._is_forced_changepoint (run_length )
138150
151+ # Get algorithm state if collecting
152+ algorithm_state = self .__algorithm .state if self .__collect_states else None
153+
139154 yield OnlineDetectionStepResult (
140155 step_num = step ,
141156 is_in_skip_period = False ,
142157 is_change_point = is_change_point ,
143158 is_force_change_point = is_forced ,
144159 detection_function = detection_func ,
145160 processing_time = step_finish_time - step_start_time ,
146- algorithm_state = self . __algorithm . state ,
161+ algorithm_state = algorithm_state ,
147162 )
148163
149164 # Handle change point detection
0 commit comments