Skip to content

Commit dc82c16

Browse files
committed
feat: implement optional states collection
1 parent db44f84 commit dc82c16

1 file changed

Lines changed: 18 additions & 3 deletions

File tree

pysatl_cpd/online/online_cpd_solver.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ class OnlineCpdSolver[T, ConfugrationT: OnlineAlgorithmConfiguration, StateT: On
4040
If not ``None``, forces a change-point declaration once the run
4141
length exceeds this value. Must be positive if specified.
4242
Default is ``None``.
43+
collect_states : bool, optional
44+
Whether to collect algorithm state snapshots in step results.
45+
If False, algorithm_state will be None in all step results.
46+
Default is True.
4347
4448
Raises
4549
------
@@ -56,6 +60,7 @@ def __init__(
5660
threshold: float = float("nan"),
5761
skip_period: int = 0,
5862
max_runlength: int | None = None,
63+
collect_states: bool = True,
5964
) -> None:
6065
"""
6166
Initialize the online change-point detection solver.
@@ -72,6 +77,8 @@ def __init__(
7277
Number of steps to skip after each declared change point.
7378
max_runlength : int or None, optional
7479
Maximum run length before forcing a change point.
80+
collect_states : bool, optional
81+
Whether to collect algorithm state snapshots.
7582
"""
7683
# Validate skip_period is non-negative
7784
if skip_period < 0:
@@ -86,6 +93,7 @@ def __init__(
8693
self.__threshold = threshold
8794
self.__skip_period = skip_period
8895
self.__max_runlength = max_runlength
96+
self.__collect_states = collect_states
8997

9098
self.__in_skip_period = False
9199

@@ -119,11 +127,15 @@ def run(self) -> Iterator[OnlineDetectionStepResult[StateT]]:
119127
if self.__in_skip_period:
120128
if skip_period_counter < self.__skip_period:
121129
skip_period_counter += 1
122-
yield OnlineDetectionStepResult(step_num=step, is_in_skip_period=True)
130+
yield OnlineDetectionStepResult(
131+
step_num=step,
132+
is_in_skip_period=True,
133+
algorithm_state=self.__algorithm.state if self.__collect_states else None,
134+
)
135+
continue
123136
if skip_period_counter == self.__skip_period:
124137
self.__in_skip_period = False
125138
skip_period_counter = 0
126-
continue
127139

128140
# Process observation normally
129141
step_start_time: float = time.perf_counter()
@@ -136,14 +148,17 @@ def run(self) -> Iterator[OnlineDetectionStepResult[StateT]]:
136148
is_change_point: bool = self._is_change_point(detection_func, run_length)
137149
is_forced: bool = self._is_forced_changepoint(run_length)
138150

151+
# Get algorithm state if collecting
152+
algorithm_state = self.__algorithm.state if self.__collect_states else None
153+
139154
yield OnlineDetectionStepResult(
140155
step_num=step,
141156
is_in_skip_period=False,
142157
is_change_point=is_change_point,
143158
is_force_change_point=is_forced,
144159
detection_function=detection_func,
145160
processing_time=step_finish_time - step_start_time,
146-
algorithm_state=self.__algorithm.state,
161+
algorithm_state=algorithm_state,
147162
)
148163

149164
# Handle change point detection

0 commit comments

Comments
 (0)