Skip to content

Commit 44a5fae

Browse files
Merge pull request #82 from PySATL/aisakov/dataset-loader
feat: Protocol for DataLoader Added
2 parents 397d69c + b22aebf commit 44a5fae

2 files changed

Lines changed: 82 additions & 50 deletions

File tree

pysatl_cpd/core/data_providers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
__license__ = "SPDX-License-Identifier: MIT"
1313

1414

15-
from pysatl_cpd.core.data_providers.dataset import Annotation, Dataset, PandasLabeledDataProvider, SegmentInfo
15+
from pysatl_cpd.core.data_providers.dataset import Annotation, Dataset, PandasLabeledDataProvider, RealDatasetLoader, SegmentInfo
1616
from pysatl_cpd.core.data_providers.idata_provider import DataProvider
1717
from pysatl_cpd.core.data_providers.numpy_data_provider import (
1818
NDArrayMultivariateProvider,
@@ -25,6 +25,7 @@
2525
"SegmentInfo",
2626
"PandasLabeledDataProvider",
2727
"Dataset",
28+
"RealDatasetLoader",
2829
"NDArrayMultivariateProvider",
2930
"NDArrayUnivariateProvider",
3031
]

pysatl_cpd/core/data_providers/dataset.py

Lines changed: 80 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -69,22 +69,22 @@
6969
Если берется подтаблица, индекс должен быть приведен к непрерывному.
7070
"""
7171

72-
__author__ = "Andrey"
72+
__author__ = "Andrey Isakov"
7373
__copyright__ = "Copyright (c) 2026 PySATL project"
7474
__license__ = "SPDX-License-Identifier: MIT"
7575

7676
from collections.abc import Callable, Iterator, Sequence
7777
from dataclasses import dataclass, field
7878
from pathlib import Path
79-
from typing import Any
79+
from typing import Any, Protocol
8080

8181
import numpy as np
8282
import pandas as pd
8383

8484
from pysatl_cpd.analysis.labeled_data import LabeledData
8585
from pysatl_cpd.core.typedefs import NumericArray
8686

87-
SEGMENT_COLUMN = "segments"
87+
SEGMENT_COLUMN = "segment"
8888
SEGMENT_ID_COLUMN = "segment"
8989
SEGMENT_START_COLUMN = "start"
9090
SEGMENT_END_COLUMN = "end"
@@ -144,43 +144,42 @@ def __init__(
144144
missing_columns = required_segment_columns.difference(segment_info.columns)
145145
raise ValueError(f"Segment info is missing required columns: {sorted(missing_columns)}")
146146

147-
# TODO: one underscore for private attributes
148-
self.__dataset = dataset.copy().reset_index(drop=True)
149-
self.__segment_info = segment_info.copy().reset_index(drop=True)
150-
self.__annotation = annotation
147+
self._dataset = dataset.copy().reset_index(drop=True)
148+
self._segment_info = segment_info.copy().reset_index(drop=True)
149+
self._annotation = annotation
151150

152-
self.__segment_info = self._normalize_segment_info()
151+
self._segment_info = self._normalize_segment_info()
153152
self._validate_segment_ranges()
154153

155-
raw_data = self.__dataset.loc[:, self.feature_columns].to_numpy(dtype=np.float64, copy=False)
154+
raw_data = self._dataset.loc[:, self.feature_columns].to_numpy(dtype=np.float64, copy=False)
156155
super().__init__(raw_data=raw_data, change_points=self.change_point, name=name)
157156

158157
def __iter__(self) -> Iterator[NumericArray]:
    """Iterate over the rows of the feature columns, converted to float64."""
    feature_frame = self._dataset.loc[:, self.feature_columns]
    observations = feature_frame.to_numpy(dtype=np.float64, copy=False)
    return iter(observations)
161160

162161
def __len__(self) -> int:
    """Return the number of rows in the stored dataset frame."""
    row_count = len(self._dataset)
    return row_count
164163

165164
@property
def dataset(self) -> pd.DataFrame:
    """Return a copy of the stored dataset frame.

    A copy is handed out so that callers cannot modify the provider's
    internal frame through the returned object.
    """
    snapshot = self._dataset.copy()
    return snapshot
168167

169168
@property
def segment_info(self) -> DatasetSegmentInfo:
    """Return a copy of the stored segment-info table."""
    info_snapshot = self._segment_info.copy()
    return info_snapshot
172171

173172
@property
def annotation(self) -> Annotation:
    """Return the annotation object this provider was constructed with (not copied)."""
    stored_annotation = self._annotation
    return stored_annotation
176175

177176
@property
def feature_columns(self) -> list[str]:
    """Return the dataset's column names, in frame order, excluding the segment-label column."""
    all_columns = self._dataset.columns
    return [name for name in all_columns if name != SEGMENT_COLUMN]
180179

181180
@property
182181
def change_point(self) -> tuple[int, ...]:
183-
segments = self.__dataset[SEGMENT_COLUMN].to_numpy(copy=False)
182+
segments = self._dataset[SEGMENT_COLUMN].to_numpy(copy=False)
184183
if len(segments) <= 1:
185184
return ()
186185

@@ -195,13 +194,13 @@ def select_columns(self, columns: Sequence[str]) -> "PandasLabeledDataProvider":
195194
if not requested_columns:
196195
raise ValueError("At least one feature column must be selected")
197196

198-
selected_dataset = self.__dataset.loc[:, [*requested_columns, SEGMENT_COLUMN]].copy().reset_index(drop=True)
199-
selected_segment_info = self.__segment_info.copy().reset_index(drop=True)
197+
selected_dataset = self._dataset.loc[:, [*requested_columns, SEGMENT_COLUMN]].copy().reset_index(drop=True)
198+
selected_segment_info = self._segment_info.copy().reset_index(drop=True)
200199

201200
return PandasLabeledDataProvider(
202201
dataset=selected_dataset,
203202
segment_info=selected_segment_info,
204-
annotation=self.__annotation,
203+
annotation=self._annotation,
205204
name=self.name,
206205
)
207206

@@ -215,7 +214,7 @@ def query_bisegments_indexes(self, filter_fn: SegmentFilter | None = None) -> li
215214
def query_bisegments(self, filter_fn: SegmentFilter | None = None) -> list["PandasLabeledDataProvider"]:
216215
result: list[PandasLabeledDataProvider] = []
217216
for current, next_segment in self._iter_segment_pairs(filter_fn):
218-
sliced_dataset = self.__dataset.iloc[current.start : next_segment.end + 1].copy().reset_index(drop=True)
217+
sliced_dataset = self._dataset.iloc[current.start : next_segment.end + 1].copy().reset_index(drop=True)
219218
split_index = next_segment.start - current.start
220219

221220
sliced_segment_info = pd.DataFrame(
@@ -239,16 +238,16 @@ def query_bisegments(self, filter_fn: SegmentFilter | None = None) -> list["Pand
239238
PandasLabeledDataProvider(
240239
dataset=sliced_dataset,
241240
segment_info=sliced_segment_info,
242-
annotation=self.__annotation,
241+
annotation=self._annotation,
243242
name=f"{self.name}:{current.segment}->{next_segment.segment}",
244243
)
245244
)
246245

247246
return result
248247

249248
def _normalize_segment_info(self) -> DatasetSegmentInfo:
250-
unique_segments = self.__dataset[SEGMENT_COLUMN].drop_duplicates().tolist()
251-
normalized_info = self.__segment_info.copy()
249+
unique_segments = self._dataset[SEGMENT_COLUMN].drop_duplicates().tolist()
250+
normalized_info = self._segment_info.copy()
252251

253252
if SEGMENT_ID_COLUMN in normalized_info.columns:
254253
normalized_rows: list[pd.Series[Any]] = []
@@ -270,11 +269,11 @@ def _normalize_segment_info(self) -> DatasetSegmentInfo:
270269
return normalized_info
271270

272271
def _validate_segment_ranges(self) -> None:
273-
data_length = len(self.__dataset)
272+
data_length = len(self._dataset)
274273
if data_length == 0:
275274
return
276275

277-
for _, segment_row in self.__segment_info.iterrows():
276+
for _, segment_row in self._segment_info.iterrows():
278277
start = int(segment_row[SEGMENT_START_COLUMN])
279278
end = int(segment_row[SEGMENT_END_COLUMN])
280279
if start < 0:
@@ -297,7 +296,7 @@ def _iter_segment_pairs(self, filter_fn: SegmentFilter | None) -> list[tuple[Seg
297296

298297
def _segment_infos(self) -> list[SegmentInfo]:
299298
segment_infos: list[SegmentInfo] = []
300-
for _, row in self.__segment_info.iterrows():
299+
for _, row in self._segment_info.iterrows():
301300
row_dict = row.to_dict()
302301
start = int(row_dict.pop(SEGMENT_START_COLUMN))
303302
end = int(row_dict.pop(SEGMENT_END_COLUMN))
@@ -322,43 +321,75 @@ class Dataset(Sequence[PandasLabeledDataProvider]):
322321
def __init__(self, timeserieses: Sequence[PandasLabeledDataProvider]) -> None:
    """Create a dataset from a sequence of labeled providers.

    The sequence is copied into a new list, so later changes to the
    caller's container do not affect this dataset.
    """
    providers = list(timeserieses)
    self._timeserieses = providers
337326

338327
def __getitem__(self, index: int) -> PandasLabeledDataProvider:
    """Return the provider stored at position ``index``."""
    selected = self._timeserieses[index]
    return selected
340329

341330
def __len__(self) -> int:
    """Return the number of providers held by the dataset."""
    provider_count = len(self._timeserieses)
    return provider_count
343332

344333
@property
def timeserieses(self) -> list[PandasLabeledDataProvider]:
    """Return a new list containing the dataset's providers.

    Only the list is copied; the provider objects themselves are shared.
    """
    return [*self._timeserieses]
351336

352337
def filter_by_annotation(self, annotation_filter: AnnotationFilter) -> "Dataset":
    """Return a new dataset with only the providers whose annotation passes the filter.

    :param annotation_filter: predicate called with each provider's
        ``annotation``; providers for which it returns a truthy value are
        kept, in their original order.
    :return: a new ``Dataset``; this instance is left unchanged.
    """
    filtered_timeserieses = [provider for provider in self._timeserieses if annotation_filter(provider.annotation)]
    # ``Dataset.__init__`` takes only the providers: the dataset no longer
    # stores a preprocessor, so forwarding ``timeseries_preprocessor`` (or
    # reading ``self._timeseries_preprocessor``) would raise at call time.
    return Dataset(filtered_timeserieses)
355340

356341
def select_bisegments_by_filter(self, filter_fn: SegmentFilter | None = None) -> list[PandasLabeledDataProvider]:
    """Collect the bisegments of every provider into one flat list.

    ``filter_fn`` is passed unchanged to each provider's
    ``query_bisegments``; results keep provider order, then the order
    returned by each provider.
    """
    return [
        bisegment
        for provider in self._timeserieses
        for bisegment in provider.query_bisegments(filter_fn)
    ]
361346

362347

363-
def _identity(frame: pd.DataFrame) -> pd.DataFrame:
364-
return frame
348+
class DatasetLoader(Protocol):
    """Structural interface for objects that build a ``Dataset`` from a filesystem path."""

    def load(self, path: Path, timeseries_preprocessor: TimeseriesPreprocessor | None = None) -> Dataset:
        """Load a dataset from ``path``.

        :param path: location to load the dataset from.
        :param timeseries_preprocessor: optional transformation for the
            loaded timeseries data; how it is applied is up to the
            implementation.
        :return: the loaded ``Dataset``.
        """
        ...
350+
351+
352+
class RealDatasetLoader(DatasetLoader):
    """Load a ``Dataset`` from ``timeseries*.csv`` files found under a directory.

    Every matching CSV file becomes one ``PandasLabeledDataProvider``: its
    segment table is derived from the file's ``SEGMENT_COLUMN`` labels and its
    annotation is built from the path of the file it was read from.
    """

    @staticmethod
    def prepare_annotation(file: str | Path) -> Annotation:
        """Build the annotation for a timeseries read from ``file``."""
        return Annotation(path=file)

    @staticmethod
    def prepare_segment_info(timeseries: pd.DataFrame) -> DatasetSegmentInfo:
        """Derive the segment table from the per-row labels in ``SEGMENT_COLUMN``.

        A segment is a maximal run of consecutive rows with equal labels. The
        result has one row per run with its label, first row position and last
        row position (both inclusive). A label that reappears after a
        different one produces a separate row.

        :param timeseries: frame containing a ``SEGMENT_COLUMN`` column.
        :return: frame with ``SEGMENT_ID_COLUMN``, ``SEGMENT_START_COLUMN`` and
            ``SEGMENT_END_COLUMN`` columns; empty (columns only) when
            ``timeseries`` has no rows.
        :raises ValueError: if ``SEGMENT_COLUMN`` is missing.
        """
        if SEGMENT_COLUMN not in timeseries.columns:
            raise ValueError(f"Timeseries must contain '{SEGMENT_COLUMN}' column")

        segments = timeseries[SEGMENT_COLUMN].to_numpy(copy=False)
        if len(segments) == 0:
            return pd.DataFrame(columns=[SEGMENT_ID_COLUMN, SEGMENT_START_COLUMN, SEGMENT_END_COLUMN])

        # Positions where the label differs from the previous row start a new segment.
        change_points = np.flatnonzero(segments[1:] != segments[:-1]) + 1
        starts = np.concatenate([[0], change_points])
        # Each segment ends right before the next one starts; the last one ends at the final row.
        ends = np.concatenate([change_points - 1, [len(segments) - 1]])
        segment_ids = segments[starts]

        return pd.DataFrame(
            {
                SEGMENT_ID_COLUMN: segment_ids,
                SEGMENT_START_COLUMN: starts,
                SEGMENT_END_COLUMN: ends,
            }
        )

    @classmethod
    def load(cls, path: Path, timeseries_preprocessor: TimeseriesPreprocessor | None = None) -> Dataset:
        """Read every ``timeseries*.csv`` under ``path`` (recursively) into a ``Dataset``.

        :param path: root directory to search.
        :param timeseries_preprocessor: optional callable applied to each
            frame right after it is read and before segment info is derived.
        :return: dataset with one provider per file, ordered by sorted file path.
        :raises FileNotFoundError: if no matching file is found.
        """
        # Path.glob yields matches in arbitrary, filesystem-dependent order;
        # sorting keeps dataset indices reproducible across runs and machines.
        timeseries_files = sorted(Path(path).glob("**/timeseries*.csv"))
        if not timeseries_files:
            raise FileNotFoundError(f"No timeseries files found in {path}")

        timeserieses: list[PandasLabeledDataProvider] = []
        for file in timeseries_files:
            timeseries = pd.read_csv(file)
            if timeseries_preprocessor is not None:
                timeseries = timeseries_preprocessor(timeseries)

            segment_info = cls.prepare_segment_info(timeseries)
            annotation = cls.prepare_annotation(file)
            timeserieses.append(PandasLabeledDataProvider(timeseries, segment_info=segment_info, annotation=annotation))

        return Dataset(timeserieses)

0 commit comments

Comments
 (0)