Skip to content

Commit db44f84

Browse files
committed
refactor: make online entities generic over state and configuration types
Introduce generic type parameters for state and configuration across the online detection module to improve type safety and enable better integration with future visualization components. Small changes in ShewhartControlChartConfiguration (add __repr__)
1 parent e4f74d9 commit db44f84

9 files changed

Lines changed: 278 additions & 38 deletions

pysatl_cpd/_typing.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@
99
__license__ = "SPDX-License-Identifier: MIT"
1010

1111

12-
from typing import Any
13-
1412
import numpy as np
1513

16-
type NumPyNumber = np.floating[Any] | np.integer[Any]
14+
type NumPyNumber = np.float64
1715
"""Type alias for NumPy numeric types."""
1816

1917
type Number = NumPyNumber | int | float

pysatl_cpd/online/ionline_algorithm.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,9 @@
77

88
from abc import ABC, abstractmethod
99
from dataclasses import dataclass
10-
from typing import TypeVar
1110

1211
from pysatl_cpd._typing import Number
1312

14-
T = TypeVar("T")
15-
1613

1714
@dataclass(kw_only=True, frozen=True)
1815
class OnlineAlgorithmState:
@@ -51,7 +48,7 @@ class OnlineAlgorithmConfiguration:
5148
learning_period_size: int = 0
5249

5350

54-
class OnlineAlgorithm[T](ABC):
51+
class OnlineAlgorithm[T, ConfigurationT: OnlineAlgorithmConfiguration, StateT: OnlineAlgorithmState](ABC):
5552
"""
5653
Abstract base class for online change-point detection algorithms.
5754
@@ -103,7 +100,7 @@ def name(self) -> str:
103100

104101
@property
105102
@abstractmethod
106-
def configuration(self) -> OnlineAlgorithmConfiguration:
103+
def configuration(self) -> ConfigurationT:
107104
"""
108105
Configuration parameters of the algorithm.
109106
@@ -115,7 +112,7 @@ def configuration(self) -> OnlineAlgorithmConfiguration:
115112
raise NotImplementedError
116113

117114
@property
118-
def state(self) -> OnlineAlgorithmState | None:
115+
def state(self) -> StateT | None:
119116
"""
120117
Current internal state snapshot of the algorithm.
121118
@@ -147,6 +144,7 @@ def process(self, observation: T) -> Number:
147144
"""
148145
raise NotImplementedError
149146

147+
@abstractmethod
150148
def reset(self) -> None:
151149
"""
152150
Reset the algorithm to its initial state.

pysatl_cpd/online/online_cpd_solver.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,14 @@
88

99
import time
1010
from collections.abc import Iterator
11-
from typing import TypeVar
1211

1312
from pysatl_cpd._typing import Number
1413
from pysatl_cpd.data_providers import DataProvider
15-
from pysatl_cpd.online.ionline_algorithm import OnlineAlgorithm
14+
from pysatl_cpd.online.ionline_algorithm import OnlineAlgorithm, OnlineAlgorithmConfiguration, OnlineAlgorithmState
1615
from pysatl_cpd.online.online_detection_trace import OnlineDetectionStepResult
1716

18-
T = TypeVar("T")
1917

20-
21-
class OnlineCpdSolver[T]:
18+
class OnlineCpdSolver[T, ConfugrationT: OnlineAlgorithmConfiguration, StateT: OnlineAlgorithmState]:
2219
"""
2320
Sequential executor for online change-point detection.
2421
@@ -55,7 +52,7 @@ class OnlineCpdSolver[T]:
5552
def __init__(
5653
self,
5754
data_provider: DataProvider[T],
58-
algorithm: OnlineAlgorithm[T],
55+
algorithm: OnlineAlgorithm[T, ConfugrationT, StateT],
5956
threshold: float = float("nan"),
6057
skip_period: int = 0,
6158
max_runlength: int | None = None,
@@ -92,7 +89,7 @@ def __init__(
9289

9390
self.__in_skip_period = False
9491

95-
def run(self) -> Iterator[OnlineDetectionStepResult]:
92+
def run(self) -> Iterator[OnlineDetectionStepResult[StateT]]:
9693
"""
9794
Execute the detection loop over all observations.
9895

pysatl_cpd/online/online_detection_trace.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
@dataclass(kw_only=True)
20-
class OnlineDetectionStepResult:
20+
class OnlineDetectionStepResult[StateT: OnlineAlgorithmState]:
2121
"""
2222
Result of processing a single observation in online changepoint detection.
2323
@@ -49,11 +49,11 @@ class OnlineDetectionStepResult:
4949
is_in_skip_period: bool = False
5050
detection_function: Number = float("nan")
5151
processing_time: Number = float("nan")
52-
algorithm_state: OnlineAlgorithmState | None = None
52+
algorithm_state: StateT | None = None
5353

5454

5555
@dataclass(kw_only=True)
56-
class OnlineDetectionTrace(DetectionTrace[Number]):
56+
class OnlineDetectionTrace[StateT: OnlineAlgorithmState](DetectionTrace[Number]):
5757
"""
5858
Complete trace of online changepoint detection execution.
5959
@@ -85,15 +85,15 @@ class OnlineDetectionTrace(DetectionTrace[Number]):
8585
threshold: Number | None = None
8686
processing_time: UnivariateNumericArray
8787
observation_scores: UnivariateNumericArray
88-
algorithm_states: list[OnlineAlgorithmState | None]
88+
algorithm_states: list[StateT | None]
8989
detected_changes: list[int]
9090
skipped_observation: list[int] = field(default_factory=list)
9191
forced_change_points: list[int] = field(default_factory=list)
9292

9393
@classmethod
9494
def from_online_detection_steps(
95-
cls, threshold: Number | None, steps: Sequence[OnlineDetectionStepResult]
96-
) -> "OnlineDetectionTrace":
95+
cls, threshold: Number | None, steps: Sequence[OnlineDetectionStepResult[StateT]]
96+
) -> "OnlineDetectionTrace[StateT]":
9797
"""
9898
Construct an OnlineDetectionTrace from a sequence of step results.
9999

pysatl_cpd/online/shewhart_control_chart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def __repr__(self) -> str:
102102
return f"w = {self.window_size}"
103103

104104

105-
class ShewhartControlChart(OnlineAlgorithm[Number]):
105+
class ShewhartControlChart(OnlineAlgorithm[Number, ShewhartControlChartConfiguration, ShewhartControlChartState]):
106106
"""
107107
Shewhart control chart with sliding-window statistic.
108108

tests/online/test_ionline_algorithm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313

1414

15-
class ConcreteAlgorithm(OnlineAlgorithm[float]):
15+
class ConcreteAlgorithm(OnlineAlgorithm[float, OnlineAlgorithmConfiguration, OnlineAlgorithmState]):
1616
"""Concrete implementation of OnlineAlgorithm for testing."""
1717

1818
def __init__(

tests/online/test_online_cpd_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __iter__(self) -> Iterator[T]:
2929
return iter(self._data)
3030

3131

32-
class MockOnlineAlgorithm(OnlineAlgorithm[T]):
32+
class MockOnlineAlgorithm(OnlineAlgorithm[T, OnlineAlgorithmConfiguration, OnlineAlgorithmState]):
3333
"""Mock online algorithm for testing."""
3434

3535
def __init__(

tests/online/test_online_detection_trace.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class TestOnlineDetectionStepResult:
2424

2525
def test_default_values(self) -> None:
2626
"""Test default values for all fields."""
27-
result = OnlineDetectionStepResult()
27+
result = OnlineDetectionStepResult[OnlineAlgorithmState]()
2828

2929
assert result.step_num == 0
3030
assert result.is_change_point is False
@@ -60,7 +60,7 @@ class TestOnlineDetectionTrace:
6060
"""Test suite for OnlineDetectionTrace."""
6161

6262
@pytest.fixture
63-
def sample_steps(self) -> list[OnlineDetectionStepResult]:
63+
def sample_steps(self) -> list[OnlineDetectionStepResult[MockAlgorithmState]]:
6464
"""Create sample step results for testing."""
6565
state1 = MockAlgorithmState()
6666
state2 = MockAlgorithmState()
@@ -104,7 +104,9 @@ def sample_steps(self) -> list[OnlineDetectionStepResult]:
104104
),
105105
]
106106

107-
def test_from_online_detection_steps(self, sample_steps: list[OnlineDetectionStepResult]) -> None:
107+
def test_from_online_detection_steps(
108+
self, sample_steps: list[OnlineDetectionStepResult[MockAlgorithmState]]
109+
) -> None:
108110
"""Test constructing OnlineDetectionTrace from step results."""
109111
trace = OnlineDetectionTrace.from_online_detection_steps(threshold=0.5, steps=sample_steps)
110112

@@ -126,7 +128,7 @@ def test_from_online_detection_steps(self, sample_steps: list[OnlineDetectionSte
126128
assert trace.algorithm_states[3] is None
127129

128130
def test_from_online_detection_steps_with_none_threshold(
129-
self, sample_steps: list[OnlineDetectionStepResult]
131+
self, sample_steps: list[OnlineDetectionStepResult[MockAlgorithmState]]
130132
) -> None:
131133
"""Test constructing trace with None threshold."""
132134
trace = OnlineDetectionTrace.from_online_detection_steps(threshold=None, steps=sample_steps)
@@ -136,7 +138,7 @@ def test_from_online_detection_steps_with_none_threshold(
136138

137139
def test_from_online_detection_steps_empty(self) -> None:
138140
"""Test constructing trace from empty step sequence."""
139-
trace = OnlineDetectionTrace.from_online_detection_steps(threshold=0.5, steps=[])
141+
trace = OnlineDetectionTrace[MockAlgorithmState].from_online_detection_steps(threshold=0.5, steps=[])
140142

141143
assert trace.threshold == 0.5
142144
assert isinstance(trace.observation_scores, np.ndarray)
@@ -150,7 +152,7 @@ def test_from_online_detection_steps_empty(self) -> None:
150152

151153
def test_from_online_detection_steps_no_detections(self) -> None:
152154
"""Test constructing trace with no changepoints."""
153-
steps: list[OnlineDetectionStepResult] = [
155+
steps: list[OnlineDetectionStepResult[MockAlgorithmState]] = [
154156
OnlineDetectionStepResult(
155157
step_num=i,
156158
is_change_point=False,
@@ -208,7 +210,7 @@ def test_inherits_from_detection_trace(self) -> None:
208210
UnivariateNumericArray, np.array([0.001, 0.002, 0.003], dtype=np.float64)
209211
)
210212

211-
trace = OnlineDetectionTrace(
213+
trace = OnlineDetectionTrace[MockAlgorithmState](
212214
threshold=0.5,
213215
observation_scores=observation_scores,
214216
processing_time=processing_times,
@@ -222,7 +224,7 @@ def test_inherits_from_detection_trace(self) -> None:
222224

223225
def test_multiple_detection_types(self) -> None:
224226
"""Test trace with multiple detection types overlapping."""
225-
steps: list[OnlineDetectionStepResult] = [
227+
steps: list[OnlineDetectionStepResult[MockAlgorithmState]] = [
226228
OnlineDetectionStepResult(
227229
step_num=i,
228230
is_change_point=(i == 2),
@@ -244,7 +246,7 @@ def test_multiple_detection_types(self) -> None:
244246

245247
def test_ndarray_dtype_preservation(self) -> None:
246248
"""Test that NumPy arrays preserve float64 dtype."""
247-
steps: list[OnlineDetectionStepResult] = [
249+
steps: list[OnlineDetectionStepResult[MockAlgorithmState]] = [
248250
OnlineDetectionStepResult(
249251
step_num=i,
250252
is_change_point=False,
@@ -271,15 +273,15 @@ def test_skipped_observation_default_mutable(self) -> None:
271273
UnivariateNumericArray, np.array([0.001, 0.002], dtype=np.float64)
272274
)
273275

274-
trace1 = OnlineDetectionTrace(
276+
trace1 = OnlineDetectionTrace[MockAlgorithmState](
275277
threshold=0.5,
276278
observation_scores=observation_scores,
277279
processing_time=processing_times,
278280
algorithm_states=[],
279281
detected_changes=[1],
280282
)
281283

282-
trace2 = OnlineDetectionTrace(
284+
trace2 = OnlineDetectionTrace[MockAlgorithmState](
283285
threshold=0.5,
284286
observation_scores=observation_scores,
285287
processing_time=processing_times,
@@ -302,15 +304,15 @@ def test_forced_change_points_default_mutable(self) -> None:
302304
UnivariateNumericArray, np.array([0.001, 0.002], dtype=np.float64)
303305
)
304306

305-
trace1 = OnlineDetectionTrace(
307+
trace1 = OnlineDetectionTrace[MockAlgorithmState](
306308
threshold=0.5,
307309
observation_scores=observation_scores,
308310
processing_time=processing_times,
309311
algorithm_states=[],
310312
detected_changes=[1],
311313
)
312314

313-
trace2 = OnlineDetectionTrace(
315+
trace2 = OnlineDetectionTrace[MockAlgorithmState](
314316
threshold=0.5,
315317
observation_scores=observation_scores,
316318
processing_time=processing_times,

0 commit comments

Comments
 (0)