6969Если берется подтаблица, индекс должен быть приведен к непрерывному.
7070"""
7171
72- __author__ = "Andrey"
72+ __author__ = "Andrey Isakov "
7373__copyright__ = "Copyright (c) 2026 PySATL project"
7474__license__ = "SPDX-License-Identifier: MIT"
7575
7676from collections .abc import Callable , Iterator , Sequence
7777from dataclasses import dataclass , field
7878from pathlib import Path
79- from typing import Any
79+ from typing import Any , Protocol
8080
8181import numpy as np
8282import pandas as pd
8383
8484from pysatl_cpd .analysis .labeled_data import LabeledData
8585from pysatl_cpd .core .typedefs import NumericArray
8686
87- SEGMENT_COLUMN = "segments "
87+ SEGMENT_COLUMN = "segment "
8888SEGMENT_ID_COLUMN = "segment"
8989SEGMENT_START_COLUMN = "start"
9090SEGMENT_END_COLUMN = "end"
@@ -144,43 +144,42 @@ def __init__(
144144 missing_columns = required_segment_columns .difference (segment_info .columns )
145145 raise ValueError (f"Segment info is missing required columns: { sorted (missing_columns )} " )
146146
147- # TODO: one underscore for private attributes
148- self .__dataset = dataset .copy ().reset_index (drop = True )
149- self .__segment_info = segment_info .copy ().reset_index (drop = True )
150- self .__annotation = annotation
147+ self ._dataset = dataset .copy ().reset_index (drop = True )
148+ self ._segment_info = segment_info .copy ().reset_index (drop = True )
149+ self ._annotation = annotation
151150
152- self .__segment_info = self ._normalize_segment_info ()
151+ self ._segment_info = self ._normalize_segment_info ()
153152 self ._validate_segment_ranges ()
154153
155- raw_data = self .__dataset .loc [:, self .feature_columns ].to_numpy (dtype = np .float64 , copy = False )
154+ raw_data = self ._dataset .loc [:, self .feature_columns ].to_numpy (dtype = np .float64 , copy = False )
156155 super ().__init__ (raw_data = raw_data , change_points = self .change_point , name = name )
157156
158157 def __iter__ (self ) -> Iterator [NumericArray ]:
159- feature_values = self .__dataset .loc [:, self .feature_columns ].to_numpy (dtype = np .float64 , copy = False )
158+ feature_values = self ._dataset .loc [:, self .feature_columns ].to_numpy (dtype = np .float64 , copy = False )
160159 return iter (feature_values )
161160
162161 def __len__ (self ) -> int :
163- return len (self .__dataset )
162+ return len (self ._dataset )
164163
165164 @property
166165 def dataset (self ) -> pd .DataFrame :
167- return self .__dataset .copy ()
166+ return self ._dataset .copy ()
168167
169168 @property
170169 def segment_info (self ) -> DatasetSegmentInfo :
171- return self .__segment_info .copy ()
170+ return self ._segment_info .copy ()
172171
173172 @property
174173 def annotation (self ) -> Annotation :
175- return self .__annotation
174+ return self ._annotation
176175
177176 @property
178177 def feature_columns (self ) -> list [str ]:
179- return [column for column in self .__dataset .columns if column != SEGMENT_COLUMN ]
178+ return [column for column in self ._dataset .columns if column != SEGMENT_COLUMN ]
180179
181180 @property
182181 def change_point (self ) -> tuple [int , ...]:
183- segments = self .__dataset [SEGMENT_COLUMN ].to_numpy (copy = False )
182+ segments = self ._dataset [SEGMENT_COLUMN ].to_numpy (copy = False )
184183 if len (segments ) <= 1 :
185184 return ()
186185
@@ -195,13 +194,13 @@ def select_columns(self, columns: Sequence[str]) -> "PandasLabeledDataProvider":
195194 if not requested_columns :
196195 raise ValueError ("At least one feature column must be selected" )
197196
198- selected_dataset = self .__dataset .loc [:, [* requested_columns , SEGMENT_COLUMN ]].copy ().reset_index (drop = True )
199- selected_segment_info = self .__segment_info .copy ().reset_index (drop = True )
197+ selected_dataset = self ._dataset .loc [:, [* requested_columns , SEGMENT_COLUMN ]].copy ().reset_index (drop = True )
198+ selected_segment_info = self ._segment_info .copy ().reset_index (drop = True )
200199
201200 return PandasLabeledDataProvider (
202201 dataset = selected_dataset ,
203202 segment_info = selected_segment_info ,
204- annotation = self .__annotation ,
203+ annotation = self ._annotation ,
205204 name = self .name ,
206205 )
207206
@@ -215,7 +214,7 @@ def query_bisegments_indexes(self, filter_fn: SegmentFilter | None = None) -> li
215214 def query_bisegments (self , filter_fn : SegmentFilter | None = None ) -> list ["PandasLabeledDataProvider" ]:
216215 result : list [PandasLabeledDataProvider ] = []
217216 for current , next_segment in self ._iter_segment_pairs (filter_fn ):
218- sliced_dataset = self .__dataset .iloc [current .start : next_segment .end + 1 ].copy ().reset_index (drop = True )
217+ sliced_dataset = self ._dataset .iloc [current .start : next_segment .end + 1 ].copy ().reset_index (drop = True )
219218 split_index = next_segment .start - current .start
220219
221220 sliced_segment_info = pd .DataFrame (
@@ -239,16 +238,16 @@ def query_bisegments(self, filter_fn: SegmentFilter | None = None) -> list["Pand
239238 PandasLabeledDataProvider (
240239 dataset = sliced_dataset ,
241240 segment_info = sliced_segment_info ,
242- annotation = self .__annotation ,
241+ annotation = self ._annotation ,
243242 name = f"{ self .name } :{ current .segment } ->{ next_segment .segment } " ,
244243 )
245244 )
246245
247246 return result
248247
249248 def _normalize_segment_info (self ) -> DatasetSegmentInfo :
250- unique_segments = self .__dataset [SEGMENT_COLUMN ].drop_duplicates ().tolist ()
251- normalized_info = self .__segment_info .copy ()
249+ unique_segments = self ._dataset [SEGMENT_COLUMN ].drop_duplicates ().tolist ()
250+ normalized_info = self ._segment_info .copy ()
252251
253252 if SEGMENT_ID_COLUMN in normalized_info .columns :
254253 normalized_rows : list [pd .Series [Any ]] = []
@@ -270,11 +269,11 @@ def _normalize_segment_info(self) -> DatasetSegmentInfo:
270269 return normalized_info
271270
272271 def _validate_segment_ranges (self ) -> None :
273- data_length = len (self .__dataset )
272+ data_length = len (self ._dataset )
274273 if data_length == 0 :
275274 return
276275
277- for _ , segment_row in self .__segment_info .iterrows ():
276+ for _ , segment_row in self ._segment_info .iterrows ():
278277 start = int (segment_row [SEGMENT_START_COLUMN ])
279278 end = int (segment_row [SEGMENT_END_COLUMN ])
280279 if start < 0 :
@@ -297,7 +296,7 @@ def _iter_segment_pairs(self, filter_fn: SegmentFilter | None) -> list[tuple[Seg
297296
298297 def _segment_infos (self ) -> list [SegmentInfo ]:
299298 segment_infos : list [SegmentInfo ] = []
300- for _ , row in self .__segment_info .iterrows ():
299+ for _ , row in self ._segment_info .iterrows ():
301300 row_dict = row .to_dict ()
302301 start = int (row_dict .pop (SEGMENT_START_COLUMN ))
303302 end = int (row_dict .pop (SEGMENT_END_COLUMN ))
@@ -322,43 +321,75 @@ class Dataset(Sequence[PandasLabeledDataProvider]):
322321 def __init__ (
323322 self ,
324323 timeserieses : Sequence [PandasLabeledDataProvider ],
325- timeseries_preprocessor : TimeseriesPreprocessor | None = None ,
326324 ) -> None :
327- self .__timeserieses = list (timeserieses )
328- self .__timeseries_preprocessor = timeseries_preprocessor if timeseries_preprocessor is not None else _identity
329-
330- @classmethod
331- def load_from_dir (
332- cls ,
333- dir_path : Path ,
334- timeseries_preprocessor : TimeseriesPreprocessor | None = None ,
335- ) -> "Dataset" :
336- raise NotImplementedError (f"{ cls .__name__ } .load_from_dir is dataset-source specific. dir_path={ dir_path } " )
325+ self ._timeserieses = list (timeserieses )
337326
338327 def __getitem__ (self , index : int ) -> PandasLabeledDataProvider :
339- return self .__timeserieses [index ]
328+ return self ._timeserieses [index ]
340329
341330 def __len__ (self ) -> int :
342- return len (self .__timeserieses )
331+ return len (self ._timeserieses )
343332
344333 @property
345334 def timeserieses (self ) -> list [PandasLabeledDataProvider ]:
346- return list (self .__timeserieses )
347-
348- @property
349- def timeseries_preprocessor (self ) -> TimeseriesPreprocessor :
350- return self .__timeseries_preprocessor
335+ return list (self ._timeserieses )
351336
352337 def filter_by_annotation (self , annotation_filter : AnnotationFilter ) -> "Dataset" :
353- filtered_timeserieses = [provider for provider in self .__timeserieses if annotation_filter (provider .annotation )]
354- return Dataset (filtered_timeserieses , timeseries_preprocessor = self .__timeseries_preprocessor )
338+ filtered_timeserieses = [provider for provider in self ._timeserieses if annotation_filter (provider .annotation )]
339+ return Dataset (filtered_timeserieses , timeseries_preprocessor = self ._timeseries_preprocessor )
355340
356341 def select_bisegments_by_filter (self , filter_fn : SegmentFilter | None = None ) -> list [PandasLabeledDataProvider ]:
357342 bisegments : list [PandasLabeledDataProvider ] = []
358- for provider in self .__timeserieses :
343+ for provider in self ._timeserieses :
359344 bisegments .extend (provider .query_bisegments (filter_fn ))
360345 return bisegments
361346
362347
363- def _identity (frame : pd .DataFrame ) -> pd .DataFrame :
364- return frame
348+ class DatasetLoader (Protocol ):
349+ def load (self , path : Path , timeseries_preprocessor : TimeseriesPreprocessor | None = None ) -> Dataset : ...
350+
351+
352+ class RealDatasetLoader (DatasetLoader ):
353+ @staticmethod
354+ def prepare_annotation (file : str | Path ) -> Annotation :
355+ return Annotation (path = file )
356+
357+ @staticmethod
358+ def prepare_segment_info (timeseries : pd .DataFrame ) -> DatasetSegmentInfo :
359+ if SEGMENT_COLUMN not in timeseries .columns :
360+ raise ValueError (f"Timeseries must contain '{ SEGMENT_COLUMN } ' column" )
361+
362+ segments = timeseries [SEGMENT_COLUMN ].to_numpy (copy = False )
363+ if len (segments ) == 0 :
364+ return pd .DataFrame (columns = [SEGMENT_ID_COLUMN , SEGMENT_START_COLUMN , SEGMENT_END_COLUMN ])
365+
366+ change_points = np .flatnonzero (segments [1 :] != segments [:- 1 ]) + 1
367+ starts = np .concatenate ([[0 ], change_points ])
368+ ends = np .concatenate ([change_points - 1 , [len (segments ) - 1 ]])
369+ segment_ids = segments [starts ]
370+
371+ return pd .DataFrame (
372+ {
373+ SEGMENT_ID_COLUMN : segment_ids ,
374+ SEGMENT_START_COLUMN : starts ,
375+ SEGMENT_END_COLUMN : ends ,
376+ }
377+ )
378+
379+ @classmethod
380+ def load (cls , path : Path , timeseries_preprocessor : TimeseriesPreprocessor | None = None ) -> Dataset :
381+ timeseries_files = list (path .glob ("**/timeseries*.csv" ))
382+ if not timeseries_files :
383+ raise FileNotFoundError (f"No timeseries files found in { path } " )
384+
385+ timeserieses = []
386+ for file in timeseries_files :
387+ timeseries = pd .read_csv (file )
388+ if timeseries_preprocessor is not None :
389+ timeseries = timeseries_preprocessor (timeseries )
390+
391+ segment_info = cls .prepare_segment_info (timeseries )
392+ annotation = cls .prepare_annotation (file )
393+ timeserieses .append (PandasLabeledDataProvider (timeseries , segment_info = segment_info , annotation = annotation ))
394+
395+ return Dataset (timeserieses )
0 commit comments