Skip to content

Commit ec55b00

Browse files
Add SelectSegmentEvent (#4575)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0e5226c commit ec55b00

2 files changed

Lines changed: 26 additions & 0 deletions

File tree

src/spikeinterface/core/baseevent.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ def add_event_segment(self, event_segment):
6666
def get_num_segments(self):
6767
return len(self._event_segments)
6868

69+
def select_segment(self, segment_indices: int | list[int]):
70+
from .segmentutils import SelectSegmentEvent
71+
72+
return SelectSegmentEvent(self, segment_indices=segment_indices)
73+
6974
def get_events(
7075
self,
7176
channel_id: int | str | None = None,

src/spikeinterface/core/segmentutils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22

3+
from .baseevent import BaseEvent
34
from .baserecording import BaseRecording, BaseRecordingSegment
45
from .basesorting import BaseSorting, BaseSortingSegment
56

@@ -604,3 +605,23 @@ def __init__(self, sorting: BaseSorting, segment_indices: int | list[int]):
604605

605606

606607
select_segment_sorting = define_function_from_class(source_class=SelectSegmentSorting, name="select_segment_sorting")
608+
609+
610+
class SelectSegmentEvent(BaseEvent):
611+
def __init__(self, event: BaseEvent, segment_indices: int | list[int]):
612+
BaseEvent.__init__(self, event.channel_ids, event.structured_dtype)
613+
614+
if isinstance(segment_indices, int):
615+
segment_indices = [segment_indices]
616+
617+
num_segments = event.get_num_segments()
618+
619+
if not all(0 <= s < num_segments for s in segment_indices):
620+
raise ValueError(f"'segment_index' must be between 0 and {num_segments - 1}")
621+
622+
for seg_idx in segment_indices:
623+
seg = event._event_segments[seg_idx]
624+
self.add_event_segment(seg)
625+
626+
self._parent = event
627+
self._kwargs = {"event": event, "segment_indices": segment_indices}

0 commit comments

Comments
 (0)