Skip to content

Commit 7a23584

Browse files
authored
Centralize add_segment and segments to BaseExtractor (#4462)
1 parent 962087b commit 7a23584

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+175
-187
lines changed

src/spikeinterface/core/base.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from __future__ import annotations
2-
31
from pathlib import Path
42
import shutil
53
from typing import Any
@@ -87,6 +85,8 @@ def __init__(self, main_ids: Sequence) -> None:
8785
self._main_ids.dtype.kind in "uiSU"
8886
), f"Main IDs can only be integers (signed/unsigned) or strings, not {self._main_ids.dtype}"
8987

88+
self._segments: "list[BaseSegment]" = []
89+
9090
# dict at object level
9191
self._annotations = {}
9292

@@ -142,11 +142,18 @@ def name(self, value):
142142
# we remove the annotation if it exists
143143
_ = self._annotations.pop("name", None)
144144

145+
@property
146+
def segments(self) -> "list[BaseSegment]":
147+
return self._segments
148+
149+
def add_segment(self, segment: "BaseSegment") -> None:
150+
self._segments.append(segment)
151+
segment.set_parent_extractor(self)
152+
145153
def get_num_segments(self) -> int:
146-
# This is implemented in BaseRecording or BaseSorting
147-
raise NotImplementedError
154+
return len(self._segments)
148155

149-
def get_parent(self) -> BaseExtractor | None:
156+
def get_parent(self) -> "BaseExtractor | None":
150157
"""Returns parent object if it exists, otherwise None"""
151158
return getattr(self, "_parent", None)
152159

@@ -381,7 +388,7 @@ def delete_property(self, key) -> None:
381388

382389
def copy_metadata(
383390
self,
384-
other: BaseExtractor,
391+
other: "BaseExtractor",
385392
only_main: bool = False,
386393
ids: Iterable | slice | None = None,
387394
skip_properties: Iterable[str] | None = None,
@@ -570,7 +577,7 @@ def to_dict(
570577
return dump_dict
571578

572579
@staticmethod
573-
def from_dict(dictionary: dict, base_folder: Path | str | None = None) -> BaseExtractor:
580+
def from_dict(dictionary: dict, base_folder: Path | str | None = None) -> "BaseExtractor":
574581
"""
575582
Instantiate extractor from dictionary
576583
@@ -624,7 +631,7 @@ def save_metadata_to_folder(self, folder_metadata):
624631
values = self.get_property(key)
625632
np.save(prop_folder / (key + ".npy"), values)
626633

627-
def clone(self) -> BaseExtractor:
634+
def clone(self) -> "BaseExtractor":
628635
"""
629636
Clones an existing extractor into a new instance.
630637
"""
@@ -816,7 +823,7 @@ def dump_to_pickle(
816823
file_path.write_bytes(pickle.dumps(dump_dict))
817824

818825
@staticmethod
819-
def load(file_or_folder_path: str | Path, base_folder: Path | str | bool | None = None) -> BaseExtractor:
826+
def load(file_or_folder_path: str | Path, base_folder: Path | str | bool | None = None) -> "BaseExtractor":
820827
"""
821828
Load extractor from file path (.json or .pkl)
822829
@@ -839,7 +846,7 @@ def __reduce__(self):
839846
return (instance_constructor, intialization_args)
840847

841848
@staticmethod
842-
def load_from_folder(folder) -> BaseExtractor:
849+
def load_from_folder(folder) -> "BaseExtractor":
843850
return BaseExtractor.load(folder)
844851

845852
def _save(self, folder, **save_kwargs):
@@ -855,7 +862,7 @@ def _extra_metadata_to_folder(self, folder):
855862
# This implemented in BaseRecording for probe
856863
pass
857864

858-
def save(self, **kwargs) -> BaseExtractor:
865+
def save(self, **kwargs) -> "BaseExtractor":
859866
"""
860867
Save a SpikeInterface object.
861868
@@ -891,7 +898,7 @@ def save(self, **kwargs) -> BaseExtractor:
891898

892899
save.__doc__ = save.__doc__.format(_shared_job_kwargs_doc)
893900

894-
def save_to_memory(self, sharedmem=True, **save_kwargs) -> BaseExtractor:
901+
def save_to_memory(self, sharedmem=True, **save_kwargs) -> "BaseExtractor":
895902
save_kwargs.pop("format", None)
896903

897904
cached = self._save(format="memory", sharedmem=sharedmem, **save_kwargs)
@@ -1092,7 +1099,7 @@ def save_to_zarr(
10921099
return cached
10931100

10941101

1095-
def _load_extractor_from_dict(dic) -> BaseExtractor:
1102+
def _load_extractor_from_dict(dic) -> "BaseExtractor":
10961103
"""
10971104
Convert a dictionary into an instance of BaseExtractor or its subclass.
10981105

src/spikeinterface/core/baserecording.py

Lines changed: 27 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from __future__ import annotations
21
import warnings
2+
from typing import Literal
33
from pathlib import Path
44

55
import numpy as np
@@ -43,9 +43,6 @@ def __init__(self, sampling_frequency: float, channel_ids: list, dtype):
4343
BaseRecordingSnippets.__init__(
4444
self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype
4545
)
46-
47-
self._recording_segments: list[BaseRecordingSegment] = []
48-
4946
# initialize main annotation and properties
5047
self.annotate(is_filtered=False)
5148

@@ -171,28 +168,20 @@ def __sub__(self, other):
171168

172169
return SubtractRecordings(self, other)
173170

174-
def get_num_segments(self) -> int:
175-
"""
176-
Returns the number of segments.
177-
178-
Returns
179-
-------
180-
int
181-
Number of segments in the recording
182-
"""
183-
return len(self._recording_segments)
171+
@property
172+
def segments(self) -> list["BaseRecordingSegment"]:
173+
"""List of recording segments."""
174+
return self._segments
184175

185-
def add_recording_segment(self, recording_segment):
176+
def add_recording_segment(self, recording_segment: "BaseRecordingSegment") -> None:
186177
"""Adds a recording segment.
187178
188179
Parameters
189180
----------
190181
recording_segment : BaseRecordingSegment
191182
The recording segment to add
192183
"""
193-
# todo: check channel count and sampling frequency
194-
self._recording_segments.append(recording_segment)
195-
recording_segment.set_parent_extractor(self)
184+
super().add_segment(recording_segment)
196185

197186
def get_num_samples(self, segment_index: int | None = None) -> int:
198187
"""
@@ -211,7 +200,7 @@ def get_num_samples(self, segment_index: int | None = None) -> int:
211200
The number of samples
212201
"""
213202
segment_index = self._check_segment_index(segment_index)
214-
return int(self._recording_segments[segment_index].get_num_samples())
203+
return int(self.segments[segment_index].get_num_samples())
215204

216205
get_num_frames = get_num_samples
217206

@@ -305,7 +294,7 @@ def get_traces(
305294
start_frame: int | None = None,
306295
end_frame: int | None = None,
307296
channel_ids: list | np.ndarray | tuple | None = None,
308-
order: "C" | "F" | None = None,
297+
order: Literal["C", "F"] | None = None,
309298
return_scaled: bool | None = None,
310299
return_in_uV: bool = False,
311300
) -> np.ndarray:
@@ -343,7 +332,7 @@ def get_traces(
343332
"""
344333
segment_index = self._check_segment_index(segment_index)
345334
channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True)
346-
rs = self._recording_segments[segment_index]
335+
rs = self.segments[segment_index]
347336
start_frame = int(start_frame) if start_frame is not None else 0
348337
num_samples = rs.get_num_samples()
349338
end_frame = int(min(end_frame, num_samples)) if end_frame is not None else num_samples
@@ -401,7 +390,7 @@ def get_time_info(self, segment_index=None) -> dict:
401390
"""
402391

403392
segment_index = self._check_segment_index(segment_index)
404-
rs = self._recording_segments[segment_index]
393+
rs = self.segments[segment_index]
405394
time_kwargs = rs.get_times_kwargs()
406395

407396
return time_kwargs
@@ -425,7 +414,7 @@ def get_times(self, segment_index=None) -> np.ndarray:
425414
The 1d times array
426415
"""
427416
segment_index = self._check_segment_index(segment_index)
428-
rs = self._recording_segments[segment_index]
417+
rs = self.segments[segment_index]
429418
times = rs.get_times()
430419
return times
431420

@@ -443,7 +432,7 @@ def get_start_time(self, segment_index=None) -> float:
443432
The start time in seconds
444433
"""
445434
segment_index = self._check_segment_index(segment_index)
446-
rs = self._recording_segments[segment_index]
435+
rs = self.segments[segment_index]
447436
return rs.get_start_time()
448437

449438
def get_end_time(self, segment_index=None) -> float:
@@ -460,7 +449,7 @@ def get_end_time(self, segment_index=None) -> float:
460449
The stop time in seconds
461450
"""
462451
segment_index = self._check_segment_index(segment_index)
463-
rs = self._recording_segments[segment_index]
452+
rs = self.segments[segment_index]
464453
return rs.get_end_time()
465454

466455
def has_time_vector(self, segment_index: int | None = None):
@@ -477,7 +466,7 @@ def has_time_vector(self, segment_index: int | None = None):
477466
True if the recording has time vectors, False otherwise
478467
"""
479468
segment_index = self._check_segment_index(segment_index)
480-
rs = self._recording_segments[segment_index]
469+
rs = self.segments[segment_index]
481470
d = rs.get_times_kwargs()
482471
return d["time_vector"] is not None
483472

@@ -494,7 +483,7 @@ def set_times(self, times, segment_index=None, with_warning=True):
494483
If True, a warning is printed
495484
"""
496485
segment_index = self._check_segment_index(segment_index)
497-
rs = self._recording_segments[segment_index]
486+
rs = self.segments[segment_index]
498487

499488
assert times.ndim == 1, "Time must have ndim=1"
500489
assert rs.get_num_samples() == times.shape[0], "times have wrong shape"
@@ -517,7 +506,7 @@ def reset_times(self):
517506
segment's sampling frequency is set to the recording's sampling frequency.
518507
"""
519508
for segment_index in range(self.get_num_segments()):
520-
rs = self._recording_segments[segment_index]
509+
rs = self.segments[segment_index]
521510
if self.has_time_vector(segment_index):
522511
rs.time_vector = None
523512
rs.t_start = None
@@ -545,7 +534,7 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N
545534
segments_to_shift = (segment_index,)
546535

547536
for segment_index in segments_to_shift:
548-
rs = self._recording_segments[segment_index]
537+
rs = self.segments[segment_index]
549538

550539
if self.has_time_vector(segment_index=segment_index):
551540
rs.time_vector += shift
@@ -558,19 +547,19 @@ def sample_index_to_time(self, sample_ind, segment_index=None):
558547
Transform sample index into time in seconds
559548
"""
560549
segment_index = self._check_segment_index(segment_index)
561-
rs = self._recording_segments[segment_index]
550+
rs = self.segments[segment_index]
562551
return rs.sample_index_to_time(sample_ind)
563552

564553
def time_to_sample_index(self, time_s, segment_index=None):
565554
segment_index = self._check_segment_index(segment_index)
566-
rs = self._recording_segments[segment_index]
555+
rs = self.segments[segment_index]
567556
return rs.time_to_sample_index(time_s)
568557

569558
def _get_t_starts(self):
570559
# handle t_starts
571560
t_starts = []
572561
has_time_vectors = []
573-
for rs in self._recording_segments:
562+
for rs in self.segments:
574563
d = rs.get_times_kwargs()
575564
t_starts.append(d["t_start"])
576565

@@ -580,7 +569,7 @@ def _get_t_starts(self):
580569

581570
def _get_time_vectors(self):
582571
time_vectors = []
583-
for rs in self._recording_segments:
572+
for rs in self.segments:
584573
d = rs.get_times_kwargs()
585574
time_vectors.append(d["time_vector"])
586575
if all(time_vector is None for time_vector in time_vectors):
@@ -668,7 +657,7 @@ def _extra_metadata_from_folder(self, folder):
668657
self.set_probegroup(probegroup, in_place=True)
669658

670659
# load time vector if any
671-
for segment_index, rs in enumerate(self._recording_segments):
660+
for segment_index, rs in enumerate(self.segments):
672661
time_file = folder / f"times_cached_seg{segment_index}.npy"
673662
if time_file.is_file():
674663
time_vector = np.load(time_file)
@@ -681,7 +670,7 @@ def _extra_metadata_to_folder(self, folder):
681670
write_probeinterface(folder / "probe.json", probegroup)
682671

683672
# save time vector if any
684-
for segment_index, rs in enumerate(self._recording_segments):
673+
for segment_index, rs in enumerate(self.segments):
685674
d = rs.get_times_kwargs()
686675
time_vector = d["time_vector"]
687676
if time_vector is not None:
@@ -735,7 +724,7 @@ def _remove_channels(self, remove_channel_ids):
735724
sub_recording = ChannelSliceRecording(self, new_channel_ids)
736725
return sub_recording
737726

738-
def frame_slice(self, start_frame: int | None, end_frame: int | None) -> BaseRecording:
727+
def frame_slice(self, start_frame: int | None, end_frame: int | None) -> "BaseRecording":
739728
"""
740729
Returns a new recording with sliced frames. Note that this operation is not in place.
741730
@@ -757,7 +746,7 @@ def frame_slice(self, start_frame: int | None, end_frame: int | None) -> BaseRec
757746
sub_recording = FrameSliceRecording(self, start_frame=start_frame, end_frame=end_frame)
758747
return sub_recording
759748

760-
def time_slice(self, start_time: float | None, end_time: float | None) -> BaseRecording:
749+
def time_slice(self, start_time: float | None, end_time: float | None) -> "BaseRecording":
761750
"""
762751
Returns a new recording object, restricted to the time interval [start_time, end_time].
763752
@@ -815,7 +804,7 @@ def _select_segments(self, segment_indices):
815804
def get_channel_locations(
816805
self,
817806
channel_ids: list | np.ndarray | tuple | None = None,
818-
axes: "xy" | "yz" | "xz" | "xyz" = "xy",
807+
axes: Literal["xy", "yz", "xz", "xyz"] = "xy",
819808
) -> np.ndarray:
820809
"""
821810
Get the physical locations of specified channels.

0 commit comments

Comments
 (0)