Skip to content

Commit 0e5226c

Browse files
Add shift_times to BaseSorting (#4551)
Co-authored-by: Alessio Buccino <alejoe9187@gmail.com>
1 parent d337cf7 commit 0e5226c

7 files changed

Lines changed: 223 additions & 32 deletions

File tree

src/spikeinterface/core/basesorting.py

Lines changed: 73 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -278,9 +278,18 @@ def get_unit_spike_train_in_seconds(
278278

279279
# Use the native spiking times if available
280280
# Some instances might implement a method themselves to access spike times directly without having to convert
281-
# (e.g. NWB extractors)
281+
# (e.g. NWB extractors). The native times already include the extractor's `_native_t_start`,
282+
# so we apply only the shift (`_t_start - _native_t_start`) on top.
282283
if hasattr(segment, "get_unit_spike_train_in_seconds"):
283-
return segment.get_unit_spike_train_in_seconds(unit_id=unit_id, start_time=start_time, end_time=end_time)
284+
spike_times = segment.get_unit_spike_train_in_seconds(
285+
unit_id=unit_id, start_time=start_time, end_time=end_time
286+
)
287+
t_start = segment._t_start if segment._t_start is not None else 0
288+
native_t_start = segment._native_t_start if segment._native_t_start is not None else 0
289+
shift = t_start - native_t_start
290+
if shift != 0:
291+
spike_times = spike_times + shift
292+
return spike_times
284293

285294
# If no recording attached and all back to frame-based conversion
286295
# Get spike train in frames and convert to times using traditional method
@@ -330,8 +339,12 @@ def register_recording(self, recording, check_spike_frames: bool = True):
330339
# Copy the recording's start times into the sorting segments. This way,
331340
# the sorting preserves the start time even if the recording is later
332341
# detached (e.g. analyzer saved and reloaded without the recording).
342+
# Also update `_native_t_start` so any subsequent `shift_times` call measures
343+
# its delta from the recording's start time (not the extractor's original value).
333344
for segment_index, segment in enumerate(self.segments):
334-
segment._t_start = recording.get_start_time(segment_index=segment_index)
345+
start_time = recording.get_start_time(segment_index=segment_index)
346+
segment._t_start = start_time
347+
segment._native_t_start = start_time
335348

336349
@property
337350
def sorting_info(self):
@@ -374,11 +387,38 @@ def get_start_time(self, segment_index: int | None = None) -> float:
374387
segment = self.segments[segment_index]
375388
return segment._t_start if segment._t_start is not None else 0.0
376389

390+
def shift_times(self, shift: int | float, segment_index: int | None = None) -> None:
391+
"""
392+
Shift all times by a scalar value.
393+
394+
This modifies the sorting's own time offset without touching the registered
395+
recording. When a recording is registered, the shift is applied on top of
396+
the recording's time basis when resolving timestamps.
397+
398+
Parameters
399+
----------
400+
shift : int | float
401+
The shift to apply. If positive, times will be increased by `shift`.
402+
If negative, times will be decreased.
403+
segment_index : int | None
404+
The segment on which to shift the times.
405+
If `None`, all segments will be shifted.
406+
"""
407+
if segment_index is None:
408+
segments_to_shift = range(self.get_num_segments())
409+
else:
410+
segments_to_shift = (segment_index,)
411+
412+
for segment_index in segments_to_shift:
413+
segment = self.segments[segment_index]
414+
segment._t_start = (segment._t_start if segment._t_start is not None else 0) + shift
415+
377416
def get_end_time(self, segment_index: int | None = None) -> float:
378417
"""Get the end time of the sorting segment.
379418
380-
If a recording is registered, returns the recording's end time.
381-
Otherwise returns the time of the last spike in the segment.
419+
If a recording is registered, returns the recording's end time (plus any
420+
shift applied via `shift_times`). Otherwise returns the time of the last
421+
spike in the segment.
382422
383423
Parameters
384424
----------
@@ -392,7 +432,10 @@ def get_end_time(self, segment_index: int | None = None) -> float:
392432
"""
393433
segment_index = self._check_segment_index(segment_index)
394434
if self.has_recording():
395-
return self._recording.get_end_time(segment_index=segment_index)
435+
segment = self.segments[segment_index]
436+
t_start = segment._t_start if segment._t_start is not None else 0
437+
shift = t_start - self._recording.get_start_time(segment_index=segment_index)
438+
return self._recording.get_end_time(segment_index=segment_index) + shift
396439
else:
397440
last_spike_frame = self.get_last_spike_frame(segment_index=segment_index)
398441
return self.sample_index_to_time(last_spike_frame, segment_index=segment_index)
@@ -430,11 +473,19 @@ def get_times(
430473
* if the segment has a time_vector, then it is returned
431474
* if not, a time_vector is constructed on the fly with sampling frequency
432475
476+
Any shift applied via `shift_times` is added to the returned times.
477+
433478
If there is no registered recording it returns None
434479
"""
435480
segment_index = self._check_segment_index(segment_index)
436481
if self.has_recording():
437-
return self._recording.get_times(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame)
482+
times = self._recording.get_times(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame)
483+
segment = self.segments[segment_index]
484+
t_start = segment._t_start if segment._t_start is not None else 0
485+
shift = t_start - self._recording.get_start_time(segment_index=segment_index)
486+
if shift != 0:
487+
times = times + shift
488+
return times
438489
else:
439490
return None
440491

@@ -776,11 +827,13 @@ def time_to_sample_index(self, time, segment_index=0):
776827
"""
777828
Transform time in seconds into sample index
778829
"""
830+
segment = self.segments[segment_index]
831+
t_start = segment._t_start if segment._t_start is not None else 0
779832
if self.has_recording():
780-
sample_index = self._recording.time_to_sample_index(time, segment_index=segment_index)
833+
# Subtract the sorting's shift (relative to the recording's start) before delegating
834+
shift = t_start - self._recording.get_start_time(segment_index=segment_index)
835+
sample_index = self._recording.time_to_sample_index(time - shift, segment_index=segment_index)
781836
else:
782-
segment = self.segments[segment_index]
783-
t_start = segment._t_start if segment._t_start is not None else 0
784837
sample_index = round((time - t_start) * self.get_sampling_frequency())
785838

786839
return sample_index
@@ -792,11 +845,13 @@ def sample_index_to_time(
792845
Transform sample index into time in seconds
793846
"""
794847
segment_index = self._check_segment_index(segment_index)
848+
segment = self.segments[segment_index]
849+
t_start = segment._t_start if segment._t_start is not None else 0
795850
if self.has_recording():
796-
return self._recording.sample_index_to_time(sample_index, segment_index=segment_index)
851+
# Add the sorting's shift (relative to the recording's start) after delegating
852+
shift = t_start - self._recording.get_start_time(segment_index=segment_index)
853+
return self._recording.sample_index_to_time(sample_index, segment_index=segment_index) + shift
797854
else:
798-
segment = self.segments[segment_index]
799-
t_start = segment._t_start if segment._t_start is not None else 0
800855
return (sample_index / self.get_sampling_frequency()) + t_start
801856

802857
def precompute_spike_trains(self):
@@ -1154,6 +1209,11 @@ class BaseSortingSegment(BaseSegment):
11541209

11551210
def __init__(self, t_start=None):
11561211
self._t_start = t_start
1212+
# Immutable reference to the start time as set by the extractor at init.
1213+
# Used to compute the user-applied shift as `_t_start - _native_t_start`,
1214+
# so `shift_times` can correctly propagate through extractors that return
1215+
# native absolute times (e.g. NWB) without double-counting the extractor's offset.
1216+
self._native_t_start = t_start
11571217
BaseSegment.__init__(self)
11581218

11591219
def get_unit_spike_train(

src/spikeinterface/core/frameslicesorting.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike
7676

7777
# link sorting segment
7878
parent_segment = parent_sorting.segments[0]
79-
sub_segment = FrameSliceSortingSegment(parent_segment, start_frame, end_frame)
79+
sub_segment = FrameSliceSortingSegment(
80+
parent_segment, start_frame, end_frame, sampling_frequency=parent_sorting.get_sampling_frequency()
81+
)
8082
self.add_sorting_segment(sub_segment)
8183

8284
# copy properties and annotations
@@ -96,8 +98,13 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike
9698

9799

98100
class FrameSliceSortingSegment(BaseSortingSegment):
99-
def __init__(self, parent_sorting_segment, start_frame, end_frame):
100-
BaseSortingSegment.__init__(self)
101+
def __init__(self, parent_sorting_segment, start_frame, end_frame, sampling_frequency):
102+
# Propagate the parent's start time forward by the slice offset, mirroring
103+
# what FrameSliceRecordingSegment does. A parent with `_t_start=None` is
104+
# treated as starting at 0, so the slice gets a concrete `start_frame / fs`.
105+
parent_t_start = parent_sorting_segment._t_start if parent_sorting_segment._t_start is not None else 0.0
106+
t_start = parent_t_start + start_frame / sampling_frequency
107+
BaseSortingSegment.__init__(self, t_start=t_start)
101108
self._parent_sorting_segment = parent_sorting_segment
102109
self.start_frame = start_frame
103110
self.end_frame = end_frame

src/spikeinterface/core/generate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,9 @@ def generate_sorting(
193193
if t_starts is not None:
194194
assert len(t_starts) == len(durations), "t_starts must have the same length as durations"
195195
for segment_index, t_start in enumerate(t_starts):
196-
sorting.segments[segment_index]._t_start = t_start
196+
segment = sorting.segments[segment_index]
197+
segment._t_start = float(t_start)
198+
segment._native_t_start = float(t_start)
197199

198200
return sorting
199201

src/spikeinterface/core/tests/test_frameslicesorting.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,5 +91,25 @@ def test_FrameSliceSorting():
9191
assert_raises(Exception, sorting_exceeding.frame_slice, None, None)
9292

9393

94+
def test_time_slice_propagates_t_start():
95+
"""`time_slice` goes through `frame_slice`, so the propagated start time should
96+
equal the requested `start_time`. Covers both the parent-with-no-t_start case
97+
and the parent-with-explicit-t_start case (which should stack)."""
98+
sf = 10.0
99+
spike_times = {"0": np.arange(100, 900)}
100+
101+
# Parent has no explicit t_start (treated as 0).
102+
sorting = NumpySorting.from_unit_dict([spike_times], sf)
103+
sub = sorting.time_slice(start_time=20.0, end_time=50.0)
104+
assert sub.get_start_time(segment_index=0) == 20.0
105+
106+
# Parent has an explicit t_start; the slice offset stacks on top.
107+
sorting_shifted = NumpySorting.from_unit_dict([spike_times], sf)
108+
sorting_shifted.shift_times(shift=100.0)
109+
sub_shifted = sorting_shifted.time_slice(start_time=120.0, end_time=150.0)
110+
assert sub_shifted.get_start_time(segment_index=0) == 120.0
111+
112+
94113
if __name__ == "__main__":
95114
test_FrameSliceSorting()
115+
test_time_slice_propagates_t_start()

src/spikeinterface/core/tests/test_time_handling.py

Lines changed: 115 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,8 @@ def test_sorting_analyzer_get_durations_from_recording(self, time_vector_recordi
286286
"""
287287
_, times_recording, _ = time_vector_recording
288288

289-
sorting = si.generate_sorting(
290-
durations=[times_recording.get_duration(s) for s in range(times_recording.get_num_segments())]
291-
)
289+
durations = [times_recording.get_duration(s) for s in range(times_recording.get_num_segments())]
290+
sorting = si.generate_sorting(durations=durations)
292291
sorting_analyzer = si.create_sorting_analyzer(sorting, recording=times_recording)
293292

294293
assert np.array_equal(sorting_analyzer.get_total_duration(), times_recording.get_total_duration())
@@ -484,10 +483,51 @@ def test_get_end_time_is_last_spike(self):
484483
assert sorting.get_end_time(segment_index=0) == expected_time
485484

486485
def test_get_start_time_with_t_start(self):
487-
sorting = generate_sorting(num_units=5, durations=[10])
488-
sorting.segments[0]._t_start = 100.0
486+
sorting = generate_sorting(num_units=5, durations=[10], t_starts=[100.0])
489487
assert sorting.get_start_time(segment_index=0) == 100.0
490488

489+
def test_shift_times(self):
490+
sorting = generate_sorting(num_units=5, durations=[10])
491+
unit_id = sorting.unit_ids[0]
492+
493+
spike_times_before = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True)
494+
495+
sorting.shift_times(shift=5.0)
496+
497+
assert sorting.get_start_time(segment_index=0) == 5.0
498+
spike_times_after = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True)
499+
assert np.allclose(spike_times_after, spike_times_before + 5.0)
500+
501+
def test_shift_times_all_segments(self):
502+
sorting = generate_sorting(num_units=5, durations=[10, 15], t_starts=[1.0, 2.0])
503+
504+
sorting.shift_times(shift=3.0)
505+
506+
assert sorting.get_start_time(segment_index=0) == 4.0
507+
assert sorting.get_start_time(segment_index=1) == 5.0
508+
509+
def test_shift_times_single_segment(self):
510+
sorting = generate_sorting(num_units=5, durations=[10, 15], t_starts=[1.0, 2.0])
511+
512+
sorting.shift_times(shift=3.0, segment_index=1)
513+
514+
assert sorting.get_start_time(segment_index=0) == 1.0
515+
assert sorting.get_start_time(segment_index=1) == 5.0
516+
517+
def test_shift_times_with_native_spike_times(self):
518+
"""Shift must apply even when the segment provides native spike times (e.g. NWB extractors)."""
519+
sorting = generate_sorting(num_units=5, durations=[10])
520+
unit_id = sorting.unit_ids[0]
521+
segment = sorting.segments[0]
522+
523+
# Simulate a segment that provides native spike times directly
524+
original_times = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True).copy()
525+
segment.get_unit_spike_train_in_seconds = lambda unit_id, start_time, end_time: original_times
526+
527+
sorting.shift_times(shift=5.0)
528+
spike_times = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True)
529+
assert np.allclose(spike_times, original_times + 5.0)
530+
491531

492532
class TestSortingTimeWithRecording:
493533
"""
@@ -504,17 +544,16 @@ def test_get_start_end_time(self):
504544
assert sorting.get_end_time(segment_index=0) == recording.get_end_time(segment_index=0)
505545

506546
def test_register_recording_copies_start_times(self):
507-
"""Registering a recording copies its start times into the sorting segments."""
508-
sorting = generate_sorting(num_units=5, durations=[10])
509-
sorting.segments[0]._t_start = 100.0
547+
"""Registering a recording overrides any pre-existing sorting start time."""
548+
sorting = generate_sorting(num_units=5, durations=[10], t_starts=[100.0])
510549

511550
recording = generate_recording(num_channels=4, durations=[10])
512551
recording.shift_times(shift=50.0)
513552
sorting.register_recording(recording)
514553

515-
# _t_start now mirrors the recording's start time, preserving it across
516-
# save/load cycles even when the recording is not attached.
517-
assert sorting.segments[0]._t_start == recording.get_start_time(segment_index=0)
554+
# The sorting's start time now mirrors the recording's start time, preserving it
555+
# across save/load cycles even when the recording is later detached.
556+
assert sorting.get_start_time(segment_index=0) == recording.get_start_time(segment_index=0)
518557
assert sorting.get_start_time(segment_index=0) == 50.0
519558

520559
def test_with_recording_shifted_start(self):
@@ -526,3 +565,68 @@ def test_with_recording_shifted_start(self):
526565
sorting.register_recording(recording)
527566

528567
assert sorting.get_start_time(segment_index=0) == 50.0
568+
569+
def test_shift_times(self):
570+
recording = generate_recording(num_channels=4, durations=[10])
571+
sorting = generate_sorting(num_units=5, durations=[10])
572+
sorting.register_recording(recording)
573+
unit_id = sorting.unit_ids[0]
574+
575+
rec_start_before = recording.get_start_time(segment_index=0)
576+
rec_end_before = recording.get_end_time(segment_index=0)
577+
spike_times_before = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True)
578+
579+
sorting.shift_times(shift=5.0)
580+
581+
# The recording should be untouched
582+
assert recording.get_start_time(segment_index=0) == rec_start_before
583+
assert recording.get_end_time(segment_index=0) == rec_end_before
584+
585+
# The sorting's times should be shifted
586+
assert sorting.get_start_time(segment_index=0) == rec_start_before + 5.0
587+
assert sorting.get_end_time(segment_index=0) == rec_end_before + 5.0
588+
spike_times_after = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True)
589+
assert np.allclose(spike_times_after, spike_times_before + 5.0)
590+
591+
def test_time_conversion_roundtrip_after_shift(self):
592+
"""sample_index_to_time and time_to_sample_index must remain inverses after a shift."""
593+
recording = generate_recording(num_channels=4, durations=[10])
594+
sorting = generate_sorting(num_units=5, durations=[10])
595+
sorting.register_recording(recording)
596+
597+
sorting.shift_times(shift=5.0)
598+
599+
# Frame 30000 is 1.0s in the recording. After a 5.0s shift, the sorting should report 6.0s.
600+
time = sorting.sample_index_to_time(30000, segment_index=0)
601+
assert time == recording.sample_index_to_time(30000, segment_index=0) + 5.0
602+
603+
# The inverse: 6.0s in the sorting should map back to frame 30000.
604+
frame = sorting.time_to_sample_index(time, segment_index=0)
605+
assert frame == 30000
606+
607+
def test_shift_times_with_time_vector(self):
608+
"""Shift on sorting composes with a recording that has an explicit time vector,
609+
preserving the irregular spacing."""
610+
recording = generate_recording(num_channels=4, durations=[1.0])
611+
num_samples = recording.get_num_samples(segment_index=0)
612+
# Irregular timestamps starting at 100.0
613+
times = (
614+
100.0
615+
+ np.cumsum(np.random.RandomState(0).uniform(0.5, 1.5, num_samples)) / recording.get_sampling_frequency()
616+
)
617+
recording.set_times(times, segment_index=0, with_warning=False)
618+
619+
sorting = generate_sorting(num_units=5, durations=[1.0])
620+
sorting.register_recording(recording)
621+
unit_id = sorting.unit_ids[0]
622+
623+
spike_times_before = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True)
624+
625+
sorting.shift_times(shift=5.0)
626+
627+
spike_times_after = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True)
628+
# Irregular spacing preserved, everything shifted by 5.0
629+
assert np.allclose(spike_times_after, spike_times_before + 5.0)
630+
631+
# Recording is untouched
632+
assert np.allclose(recording.get_times(segment_index=0), times)

src/spikeinterface/extractors/neoextractors/neobaseextractor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -618,11 +618,10 @@ def __init__(
618618
sampling_frequency,
619619
neo_returns_frames,
620620
):
621-
BaseSortingSegment.__init__(self)
621+
BaseSortingSegment.__init__(self, t_start=t_start)
622622
self.neo_reader = neo_reader
623623
self.segment_index = segment_index
624624
self.block_index = block_index
625-
self._t_start = t_start
626625
self._sampling_frequency = sampling_frequency
627626
self.neo_returns_frames = neo_returns_frames
628627

0 commit comments

Comments
 (0)