Skip to content

Commit ab29ba3

Browse files
committed
refactor: name property in providers
1 parent 9eceec6 commit ab29ba3

8 files changed

Lines changed: 67 additions & 121 deletions

File tree

pysatl_cpd/analysis/labeled_data.py

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,10 @@
88
__license__ = "SPDX-License-Identifier: MIT"
99

1010
from collections.abc import Collection, Iterator, Sequence
11-
from dataclasses import dataclass
1211

1312
from pysatl_cpd.core.data_providers import DataProvider
1413

1514

16-
@dataclass
1715
class LabeledData[T](DataProvider[T]):
1816
"""
1917
Container for labeled time series data with known change point locations.
@@ -51,37 +49,23 @@ class LabeledData[T](DataProvider[T]):
5149
6
5250
"""
5351

54-
raw_data: Collection[T]
55-
change_points: Sequence[int]
56-
name: str | None = None
52+
def __init__(self, raw_data: Collection[T], change_points: Sequence[int], name: str | None = None):
53+
super().__init__(name)
5754

58-
def __post_init__(self) -> None:
59-
"""
60-
Validate change point indices after initialization.
61-
62-
Verifies that all change point indices are positive and within
63-
the bounds of the raw data. Change points must be >= 1 because
64-
they represent the first observation index after a regime change,
65-
and index 0 would imply a change before any data exists.
66-
67-
Raises
68-
------
69-
ValueError
70-
If any change point index is <= 0.
71-
ValueError
72-
If any change point index exceeds the length of raw_data.
73-
"""
7455
# Validate that change points are positive
75-
if self.change_points and min(self.change_points) <= 0:
76-
raise ValueError(f"Change point indices must be positive (>= 1). Found index: {min(self.change_points)}")
56+
if change_points and min(change_points) <= 0:
57+
raise ValueError(f"Change point indices must be positive (>= 1). Found index: {min(change_points)}")
7758

7859
# Validate that change points are within data bounds
79-
max_index = max(self.change_points) if self.change_points else 0
80-
if max_index > len(self.raw_data):
60+
max_index = max(change_points) if change_points else 0
61+
if max_index > len(raw_data):
8162
raise ValueError(
82-
f"Change point index exceeds data length. Max index: {max_index}, data length: {len(self.raw_data)}"
63+
f"Change point index exceeds data length. Max index: {max_index}, data length: {len(raw_data)}"
8364
)
8465

66+
self.__raw_data = raw_data
67+
self.__change_points = change_points
68+
8569
def __iter__(self) -> Iterator[T]:
8670
"""
8771
Return an iterator over the raw observations.
@@ -91,7 +75,7 @@ def __iter__(self) -> Iterator[T]:
9175
Iterator[T]
9276
Iterator yielding each observation in sequence.
9377
"""
94-
return iter(self.raw_data)
78+
return iter(self.__raw_data)
9579

9680
def __len__(self) -> int:
9781
"""
@@ -102,7 +86,7 @@ def __len__(self) -> int:
10286
int
10387
Total number of observations.
10488
"""
105-
return len(self.raw_data)
89+
return len(self.__raw_data)
10690

10791
def __str__(self) -> str:
10892
"""
@@ -113,6 +97,12 @@ def __str__(self) -> str:
11397
str
11498
String representation with dataset name (if provided) and length.
11599
"""
116-
if self.name is not None:
117-
return f"{self.name} (len = {len(self)})"
118-
return f"Labeled Data (len = {len(self)})"
100+
return f"{self.name} (len = {len(self)})"
101+
102+
@property
103+
def raw_data(self) -> Collection[T]:
104+
return self.__raw_data
105+
106+
@property
107+
def change_points(self) -> Sequence[int]:
108+
return self.__change_points

pysatl_cpd/core/data_providers/idata_provider.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ class DataProvider[T](ABC):
3434
For multivariate data, T is typically a one-dimensional array.
3535
"""
3636

37+
def __init__(self, name: str | None) -> None:
38+
self._name = name if name is not None else type(self).__name__
39+
3740
@abstractmethod
3841
def __iter__(self) -> Iterator[T]:
3942
"""
@@ -59,7 +62,6 @@ def __len__(self) -> int:
5962
raise NotImplementedError # pragma: no cover
6063

6164
@property
62-
@abstractmethod
6365
def name(self) -> str:
6466
"""Return the name of the DataProvider"""
65-
raise NotImplementedError # pragma: no cover
67+
return self._name

pysatl_cpd/core/data_providers/numpy_data_provider.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class NDArrayUnivariateProvider(DataProvider[NumPyNumber]):
4949
[1.0, 2.0, 3.0, 4.0, 5.0]
5050
"""
5151

52-
def __init__(self, data: NumericArray) -> None:
52+
def __init__(self, data: NumericArray, name: str | None = None) -> None:
5353
"""
5454
Initialize the univariate provider with a NumPy array.
5555
@@ -63,6 +63,8 @@ def __init__(self, data: NumericArray) -> None:
6363
ValueError
6464
If the array is not one-dimensional.
6565
"""
66+
super().__init__(name)
67+
6668
if data.ndim != 1:
6769
raise ValueError(f"Expected 1-dimensional array, got {data.ndim} dimensions")
6870
self.__data = cast(UnivariateNumericArray, data)
@@ -89,10 +91,6 @@ def __len__(self) -> int:
8991
"""
9092
return self.__data.shape[0]
9193

92-
@property
93-
def name(self) -> str:
94-
return "NDArrayUnivariateProvider"
95-
9694

9795
class NDArrayMultivariateProvider(DataProvider[UnivariateNumericArray]):
9896
"""
@@ -126,7 +124,7 @@ class NDArrayMultivariateProvider(DataProvider[UnivariateNumericArray]):
126124
[array([1., 2.]), array([3., 4.]), array([5., 6.])]
127125
"""
128126

129-
def __init__(self, data: NumericArray) -> None:
127+
def __init__(self, data: NumericArray, name: str | None = None) -> None:
130128
"""
131129
Initialize the multivariate provider with a NumPy array.
132130
@@ -140,6 +138,8 @@ def __init__(self, data: NumericArray) -> None:
140138
ValueError
141139
If the array is not two-dimensional.
142140
"""
141+
super().__init__(name)
142+
143143
if data.ndim != 2:
144144
raise ValueError(f"Expected 2 dimensions, got {data.ndim}")
145145
self.__data = cast(MultivariateNumericArray, data)
@@ -166,7 +166,3 @@ def __len__(self) -> int:
166166
The number of rows in the underlying array (number of observations).
167167
"""
168168
return self.__data.shape[0]
169-
170-
@property
171-
def name(self) -> str:
172-
return "NDArrayMultivariateProvider"

tests/mocks/core/data_providers/dirty.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ def __init__(
4040
source: DataProvider[Number],
4141
nan_indices: list[int] | None = None,
4242
inf_indices: list[int] | None = None,
43+
name: str | None = None,
4344
) -> None:
45+
super().__init__(name)
46+
4447
self._source = source
4548
self._nan_indices = set(nan_indices or [])
4649
self._inf_indices = set(inf_indices or [])
@@ -82,10 +85,6 @@ def __getitem__(self, index: int) -> Number:
8285
return float("inf")
8386
return self._data[index]
8487

85-
@property
86-
def name(self) -> str:
87-
return "MockUnivariateDirtyDataProvider"
88-
8988
def get_call_count(self) -> int:
9089
"""Return number of times __iter__ was called."""
9190
return self._call_count
@@ -126,7 +125,10 @@ def __init__(
126125
source: DataProvider[list[Number]],
127126
nan_positions: list[tuple[int, int]] | None = None,
128127
inf_positions: list[tuple[int, int]] | None = None,
128+
name: str | None = None,
129129
) -> None:
130+
super().__init__(name)
131+
130132
self._source = source
131133
self._nan_positions = set(nan_positions or [])
132134
self._inf_positions = set(inf_positions or [])
@@ -199,10 +201,6 @@ def dimensions(self) -> int:
199201
"""Return number of dimensions (variables)."""
200202
return self._dimensions
201203

202-
@property
203-
def name(self) -> str:
204-
return "MockMultivariateDirtyDataProvider"
205-
206204
def __repr__(self) -> str:
207205
"""Return string representation."""
208206
return (

tests/mocks/core/data_providers/edge.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
class MockEmptyDataProvider[T](DataProvider[T]):
1818
"""Mock data provider that yields no observations."""
1919

20-
def __init__(self) -> None:
20+
def __init__(self, name: str | None = None) -> None:
21+
super().__init__(name)
22+
2123
self._call_count = 0
2224

2325
def __iter__(self) -> Iterator[T]:
@@ -37,10 +39,6 @@ def reset_call_count(self) -> None:
3739
"""Reset the call counter."""
3840
self._call_count = 0
3941

40-
@property
41-
def name(self) -> str:
42-
return "MockEmptyDataProvider"
43-
4442
def __repr__(self) -> str:
4543
"""Return string representation."""
4644
return "MockEmptyDataProvider()"
@@ -49,7 +47,9 @@ def __repr__(self) -> str:
4947
class MockSingleObservationProvider[T](DataProvider[T]):
5048
"""Mock data provider with a single observation."""
5149

52-
def __init__(self, observation: T) -> None:
50+
def __init__(self, observation: T, name: str | None = None) -> None:
51+
super().__init__(name)
52+
5353
self._observation = observation
5454
self._call_count = 0
5555

@@ -70,10 +70,6 @@ def reset_call_count(self) -> None:
7070
"""Reset the call counter."""
7171
self._call_count = 0
7272

73-
@property
74-
def name(self) -> str:
75-
return "MockSingleObservationProvider"
76-
7773
def __repr__(self) -> str:
7874
"""Return string representation."""
7975
return f"MockSingleObservationProvider(value={self._observation})"

tests/mocks/core/data_providers/multivariate.py

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ class MockMultivariateDataProvider(DataProvider[list[Number]]):
2929
All observations must have the same length.
3030
"""
3131

32-
def __init__(self, data: Sequence[Sequence[Number]]) -> None:
32+
def __init__(self, data: Sequence[Sequence[Number]], name: str | None = None) -> None:
33+
super().__init__(name)
34+
3335
if not data:
3436
self._data: list[list[Number]] = []
3537
self._dimensions = 0
@@ -76,10 +78,6 @@ def dimensions(self) -> int:
7678
"""Return number of dimensions (variables)."""
7779
return self._dimensions
7880

79-
@property
80-
def name(self) -> str:
81-
return "MockMultivariateDataProvider"
82-
8381
def __repr__(self) -> str:
8482
"""Return string representation."""
8583
return f"MockMultivariateDataProvider(observations={len(self)}, dimensions={self.dimensions})"
@@ -99,14 +97,10 @@ class MockMultivariateConstantDataProvider(MockMultivariateDataProvider):
9997
Number of observations to yield.
10098
"""
10199

102-
def __init__(self, value: list[Number], length: int) -> None:
100+
def __init__(self, value: list[Number], length: int, name: str | None = None) -> None:
103101
self._value = value
104102
self._length = length
105-
super().__init__([value] * length)
106-
107-
@property
108-
def name(self) -> str:
109-
return "MockMultivariateConstantDataProvider"
103+
super().__init__([value] * length, name)
110104

111105

112106
class MockMultivariateZeroDataProvider(MockMultivariateConstantDataProvider):
@@ -123,13 +117,9 @@ class MockMultivariateZeroDataProvider(MockMultivariateConstantDataProvider):
123117
Number of observations to yield.
124118
"""
125119

126-
def __init__(self, dimensions: int, length: int) -> None:
120+
def __init__(self, dimensions: int, length: int, name: str | None = None) -> None:
127121
value = [0.0] * dimensions
128-
super().__init__(value, length)
129-
130-
@property
131-
def name(self) -> str:
132-
return "MockMultivariateZeroDataProvider"
122+
super().__init__(value, length, name)
133123

134124

135125
class MockMultivariateNaNDataProvider(MockMultivariateConstantDataProvider):
@@ -146,13 +136,9 @@ class MockMultivariateNaNDataProvider(MockMultivariateConstantDataProvider):
146136
Number of observations to yield.
147137
"""
148138

149-
def __init__(self, dimensions: int, length: int) -> None:
139+
def __init__(self, dimensions: int, length: int, name: str | None = None) -> None:
150140
value = [float("nan")] * dimensions
151-
super().__init__(value, length)
152-
153-
@property
154-
def name(self) -> str:
155-
return "MockMultivariateNaNDataProvider"
141+
super().__init__(value, length, name)
156142

157143

158144
class MockMultivariateInfDataProvider(MockMultivariateConstantDataProvider):
@@ -169,10 +155,6 @@ class MockMultivariateInfDataProvider(MockMultivariateConstantDataProvider):
169155
Number of observations to yield.
170156
"""
171157

172-
def __init__(self, dimensions: int, length: int) -> None:
158+
def __init__(self, dimensions: int, length: int, name: str | None = None) -> None:
173159
value = [float("inf")] * dimensions
174-
super().__init__(value, length)
175-
176-
@property
177-
def name(self) -> str:
178-
return "MockMultivariateInfDataProvider"
160+
super().__init__(value, length, name)

0 commit comments

Comments
 (0)