1- from __future__ import annotations
21import warnings
2+ from typing import Literal
33from pathlib import Path
44
55import numpy as np
@@ -43,9 +43,6 @@ def __init__(self, sampling_frequency: float, channel_ids: list, dtype):
4343 BaseRecordingSnippets .__init__ (
4444 self , channel_ids = channel_ids , sampling_frequency = sampling_frequency , dtype = dtype
4545 )
46-
47- self ._recording_segments : list [BaseRecordingSegment ] = []
48-
4946 # initialize main annotation and properties
5047 self .annotate (is_filtered = False )
5148
@@ -171,28 +168,20 @@ def __sub__(self, other):
171168
172169 return SubtractRecordings (self , other )
173170
174- def get_num_segments (self ) -> int :
175- """
176- Returns the number of segments.
177-
178- Returns
179- -------
180- int
181- Number of segments in the recording
182- """
183- return len (self ._recording_segments )
171+ @property
172+ def segments (self ) -> list ["BaseRecordingSegment" ]:
173+ """List of recording segments."""
174+ return self ._segments
184175
185- def add_recording_segment (self , recording_segment ) :
176+ def add_recording_segment (self , recording_segment : "BaseRecordingSegment" ) -> None :
186177 """Adds a recording segment.
187178
188179 Parameters
189180 ----------
190181 recording_segment : BaseRecordingSegment
191182 The recording segment to add
192183 """
193- # todo: check channel count and sampling frequency
194- self ._recording_segments .append (recording_segment )
195- recording_segment .set_parent_extractor (self )
184+ super ().add_segment (recording_segment )
196185
197186 def get_num_samples (self , segment_index : int | None = None ) -> int :
198187 """
@@ -211,7 +200,7 @@ def get_num_samples(self, segment_index: int | None = None) -> int:
211200 The number of samples
212201 """
213202 segment_index = self ._check_segment_index (segment_index )
214- return int (self ._recording_segments [segment_index ].get_num_samples ())
203+ return int (self .segments [segment_index ].get_num_samples ())
215204
216205 get_num_frames = get_num_samples
217206
@@ -305,7 +294,7 @@ def get_traces(
305294 start_frame : int | None = None ,
306295 end_frame : int | None = None ,
307296 channel_ids : list | np .ndarray | tuple | None = None ,
308- order : "C" | "F" | None = None ,
297+ order : Literal [ "C" , "F" ] | None = None ,
309298 return_scaled : bool | None = None ,
310299 return_in_uV : bool = False ,
311300 ) -> np .ndarray :
@@ -343,7 +332,7 @@ def get_traces(
343332 """
344333 segment_index = self ._check_segment_index (segment_index )
345334 channel_indices = self .ids_to_indices (channel_ids , prefer_slice = True )
346- rs = self ._recording_segments [segment_index ]
335+ rs = self .segments [segment_index ]
347336 start_frame = int (start_frame ) if start_frame is not None else 0
348337 num_samples = rs .get_num_samples ()
349338 end_frame = int (min (end_frame , num_samples )) if end_frame is not None else num_samples
@@ -401,7 +390,7 @@ def get_time_info(self, segment_index=None) -> dict:
401390 """
402391
403392 segment_index = self ._check_segment_index (segment_index )
404- rs = self ._recording_segments [segment_index ]
393+ rs = self .segments [segment_index ]
405394 time_kwargs = rs .get_times_kwargs ()
406395
407396 return time_kwargs
@@ -425,7 +414,7 @@ def get_times(self, segment_index=None) -> np.ndarray:
425414 The 1d times array
426415 """
427416 segment_index = self ._check_segment_index (segment_index )
428- rs = self ._recording_segments [segment_index ]
417+ rs = self .segments [segment_index ]
429418 times = rs .get_times ()
430419 return times
431420
@@ -443,7 +432,7 @@ def get_start_time(self, segment_index=None) -> float:
443432 The start time in seconds
444433 """
445434 segment_index = self ._check_segment_index (segment_index )
446- rs = self ._recording_segments [segment_index ]
435+ rs = self .segments [segment_index ]
447436 return rs .get_start_time ()
448437
449438 def get_end_time (self , segment_index = None ) -> float :
@@ -460,7 +449,7 @@ def get_end_time(self, segment_index=None) -> float:
460449 The stop time in seconds
461450 """
462451 segment_index = self ._check_segment_index (segment_index )
463- rs = self ._recording_segments [segment_index ]
452+ rs = self .segments [segment_index ]
464453 return rs .get_end_time ()
465454
466455 def has_time_vector (self , segment_index : int | None = None ):
@@ -477,7 +466,7 @@ def has_time_vector(self, segment_index: int | None = None):
477466 True if the recording has time vectors, False otherwise
478467 """
479468 segment_index = self ._check_segment_index (segment_index )
480- rs = self ._recording_segments [segment_index ]
469+ rs = self .segments [segment_index ]
481470 d = rs .get_times_kwargs ()
482471 return d ["time_vector" ] is not None
483472
@@ -494,7 +483,7 @@ def set_times(self, times, segment_index=None, with_warning=True):
494483 If True, a warning is printed
495484 """
496485 segment_index = self ._check_segment_index (segment_index )
497- rs = self ._recording_segments [segment_index ]
486+ rs = self .segments [segment_index ]
498487
499488 assert times .ndim == 1 , "Time must have ndim=1"
500489 assert rs .get_num_samples () == times .shape [0 ], "times have wrong shape"
@@ -517,7 +506,7 @@ def reset_times(self):
517506 segment's sampling frequency is set to the recording's sampling frequency.
518507 """
519508 for segment_index in range (self .get_num_segments ()):
520- rs = self ._recording_segments [segment_index ]
509+ rs = self .segments [segment_index ]
521510 if self .has_time_vector (segment_index ):
522511 rs .time_vector = None
523512 rs .t_start = None
@@ -545,7 +534,7 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N
545534 segments_to_shift = (segment_index ,)
546535
547536 for segment_index in segments_to_shift :
548- rs = self ._recording_segments [segment_index ]
537+ rs = self .segments [segment_index ]
549538
550539 if self .has_time_vector (segment_index = segment_index ):
551540 rs .time_vector += shift
@@ -558,19 +547,19 @@ def sample_index_to_time(self, sample_ind, segment_index=None):
558547 Transform sample index into time in seconds
559548 """
560549 segment_index = self ._check_segment_index (segment_index )
561- rs = self ._recording_segments [segment_index ]
550+ rs = self .segments [segment_index ]
562551 return rs .sample_index_to_time (sample_ind )
563552
564553 def time_to_sample_index (self , time_s , segment_index = None ):
565554 segment_index = self ._check_segment_index (segment_index )
566- rs = self ._recording_segments [segment_index ]
555+ rs = self .segments [segment_index ]
567556 return rs .time_to_sample_index (time_s )
568557
569558 def _get_t_starts (self ):
570559 # handle t_starts
571560 t_starts = []
572561 has_time_vectors = []
573- for rs in self ._recording_segments :
562+ for rs in self .segments :
574563 d = rs .get_times_kwargs ()
575564 t_starts .append (d ["t_start" ])
576565
@@ -580,7 +569,7 @@ def _get_t_starts(self):
580569
581570 def _get_time_vectors (self ):
582571 time_vectors = []
583- for rs in self ._recording_segments :
572+ for rs in self .segments :
584573 d = rs .get_times_kwargs ()
585574 time_vectors .append (d ["time_vector" ])
586575 if all (time_vector is None for time_vector in time_vectors ):
@@ -668,7 +657,7 @@ def _extra_metadata_from_folder(self, folder):
668657 self .set_probegroup (probegroup , in_place = True )
669658
670659 # load time vector if any
671- for segment_index , rs in enumerate (self ._recording_segments ):
660+ for segment_index , rs in enumerate (self .segments ):
672661 time_file = folder / f"times_cached_seg{ segment_index } .npy"
673662 if time_file .is_file ():
674663 time_vector = np .load (time_file )
@@ -681,7 +670,7 @@ def _extra_metadata_to_folder(self, folder):
681670 write_probeinterface (folder / "probe.json" , probegroup )
682671
683672 # save time vector if any
684- for segment_index , rs in enumerate (self ._recording_segments ):
673+ for segment_index , rs in enumerate (self .segments ):
685674 d = rs .get_times_kwargs ()
686675 time_vector = d ["time_vector" ]
687676 if time_vector is not None :
@@ -735,7 +724,7 @@ def _remove_channels(self, remove_channel_ids):
735724 sub_recording = ChannelSliceRecording (self , new_channel_ids )
736725 return sub_recording
737726
738- def frame_slice (self , start_frame : int | None , end_frame : int | None ) -> BaseRecording :
727+ def frame_slice (self , start_frame : int | None , end_frame : int | None ) -> " BaseRecording" :
739728 """
740729 Returns a new recording with sliced frames. Note that this operation is not in place.
741730
@@ -757,7 +746,7 @@ def frame_slice(self, start_frame: int | None, end_frame: int | None) -> BaseRec
757746 sub_recording = FrameSliceRecording (self , start_frame = start_frame , end_frame = end_frame )
758747 return sub_recording
759748
760- def time_slice (self , start_time : float | None , end_time : float | None ) -> BaseRecording :
749+ def time_slice (self , start_time : float | None , end_time : float | None ) -> " BaseRecording" :
761750 """
762751 Returns a new recording object, restricted to the time interval [start_time, end_time].
763752
@@ -815,7 +804,7 @@ def _select_segments(self, segment_indices):
815804 def get_channel_locations (
816805 self ,
817806 channel_ids : list | np .ndarray | tuple | None = None ,
818- axes : "xy" | "yz" | "xz" | "xyz" = "xy" ,
807+ axes : Literal [ "xy" , "yz" , "xz" , "xyz" ] = "xy" ,
819808 ) -> np .ndarray :
820809 """
821810 Get the physical locations of specified channels.
0 commit comments