Skip to content

Commit 651c8ea

Browse files
Add get_start_time, get_end_time, and get_last_spike_frame to BaseSorting (#4525)
Co-authored-by: Alessio Buccino <alejoe9187@gmail.com>
1 parent a3f8db3 commit 651c8ea

2 files changed

Lines changed: 123 additions & 0 deletions

File tree

src/spikeinterface/core/basesorting.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,11 @@ def register_recording(self, recording, check_spike_frames: bool = True):
327327
"Might be necessary for further postprocessing."
328328
)
329329
self._recording = recording
330+
# Copy the recording's start times into the sorting segments. This way,
331+
# the sorting preserves the start time even if the recording is later
332+
# detached (e.g. analyzer saved and reloaded without the recording).
333+
for segment_index, segment in enumerate(self.segments):
334+
segment._t_start = recording.get_start_time(segment_index=segment_index)
330335

331336
@property
332337
def sorting_info(self):
@@ -352,6 +357,66 @@ def has_time_vector(self, segment_index: int | None = None) -> bool:
352357
else:
353358
return False
354359

360+
def get_start_time(self, segment_index: int | None = None) -> float:
361+
"""Get the start time of the sorting segment.
362+
363+
Parameters
364+
----------
365+
segment_index : int or None, default: None
366+
The segment index (required for multi-segment)
367+
368+
Returns
369+
-------
370+
float
371+
The start time in seconds
372+
"""
373+
segment_index = self._check_segment_index(segment_index)
374+
segment = self.segments[segment_index]
375+
return segment._t_start if segment._t_start is not None else 0.0
376+
377+
def get_end_time(self, segment_index: int | None = None) -> float:
378+
"""Get the end time of the sorting segment.
379+
380+
If a recording is registered, returns the recording's end time.
381+
Otherwise returns the time of the last spike in the segment.
382+
383+
Parameters
384+
----------
385+
segment_index : int or None, default: None
386+
The segment index (required for multi-segment)
387+
388+
Returns
389+
-------
390+
float
391+
The end time in seconds
392+
"""
393+
segment_index = self._check_segment_index(segment_index)
394+
if self.has_recording():
395+
return self._recording.get_end_time(segment_index=segment_index)
396+
else:
397+
last_spike_frame = self.get_last_spike_frame(segment_index=segment_index)
398+
return self.sample_index_to_time(last_spike_frame, segment_index=segment_index)
399+
400+
def get_last_spike_frame(self, segment_index: int | None = None) -> int:
401+
"""Get the frame index of the last spike in a segment across all units.
402+
403+
Parameters
404+
----------
405+
segment_index : int or None, default: None
406+
The segment index (required for multi-segment)
407+
408+
Returns
409+
-------
410+
int
411+
The frame index of the last spike, or 0 if no spikes exist.
412+
"""
413+
segment_index = self._check_segment_index(segment_index)
414+
spike_vector = self.to_spike_vector(concatenated=False)
415+
spikes_in_segment = spike_vector[segment_index]
416+
if len(spikes_in_segment) == 0:
417+
return 0
418+
return int(np.max(spikes_in_segment["sample_index"]))
419+
355420
def get_times(self, segment_index=None):
356421
"""
357422
Get time vector for a registered recording segment.

src/spikeinterface/core/tests/test_time_handling.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,3 +445,61 @@ def test_shift_times_with_None_as_t_start():
445445
assert recording.segments[0].t_start is None
446446
recording.shift_times(shift=1.0) # Shift by one seconds should not generate an error
447447
assert recording.get_start_time() == 1.0
448+
449+
450+
class TestSortingTimeNoRecording:
451+
"""Tests for time methods on BaseSorting without a registered recording."""
452+
453+
def test_get_start_time_default(self):
454+
sorting = generate_sorting(num_units=5, durations=[10])
455+
assert sorting.get_start_time(segment_index=0) == 0.0
456+
457+
def test_get_end_time_is_last_spike(self):
458+
sorting = generate_sorting(num_units=5, durations=[10])
459+
last_frame = sorting.get_last_spike_frame(segment_index=0)
460+
expected_time = last_frame / sorting.get_sampling_frequency()
461+
assert sorting.get_end_time(segment_index=0) == expected_time
462+
463+
def test_get_start_time_with_t_start(self):
464+
sorting = generate_sorting(num_units=5, durations=[10])
465+
sorting.segments[0]._t_start = 100.0
466+
assert sorting.get_start_time(segment_index=0) == 100.0
467+
468+
469+
class TestSortingTimeWithRecording:
470+
"""
471+
Tests for time methods on BaseSorting with a registered recording.
472+
The key invariant: the recording is the source of truth for timestamps.
473+
"""
474+
475+
def test_get_start_end_time(self):
476+
recording = generate_recording(num_channels=4, durations=[10])
477+
sorting = generate_sorting(num_units=5, durations=[10])
478+
sorting.register_recording(recording)
479+
480+
assert sorting.get_start_time(segment_index=0) == recording.get_start_time(segment_index=0)
481+
assert sorting.get_end_time(segment_index=0) == recording.get_end_time(segment_index=0)
482+
483+
def test_register_recording_copies_start_times(self):
484+
"""Registering a recording copies its start times into the sorting segments."""
485+
sorting = generate_sorting(num_units=5, durations=[10])
486+
sorting.segments[0]._t_start = 100.0
487+
488+
recording = generate_recording(num_channels=4, durations=[10])
489+
recording.shift_times(shift=50.0)
490+
sorting.register_recording(recording)
491+
492+
# _t_start now mirrors the recording's start time, preserving it across
493+
# save/load cycles even when the recording is not attached.
494+
assert sorting.segments[0]._t_start == recording.get_start_time(segment_index=0)
495+
assert sorting.get_start_time(segment_index=0) == 50.0
496+
497+
def test_with_recording_shifted_start(self):
498+
"""Recording with a non-zero t_start is reflected in the sorting."""
499+
recording = generate_recording(num_channels=4, durations=[10])
500+
recording.shift_times(shift=50.0)
501+
502+
sorting = generate_sorting(num_units=5, durations=[10])
503+
sorting.register_recording(recording)
504+
505+
assert sorting.get_start_time(segment_index=0) == 50.0

0 commit comments

Comments
 (0)