Skip to content

Commit d337cf7

Browse files
Improve get_times function to slice timestamps (#4512)
Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com>
1 parent 0f3cd0f commit d337cf7

6 files changed

Lines changed: 90 additions & 66 deletions

File tree

src/spikeinterface/core/basesorting.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,12 @@ def get_last_spike_frame(self, segment_index: int | None = None) -> int:
417417
return 0
418418
return int(np.max(spikes_in_segment["sample_index"]))
419419

420-
def get_times(self, segment_index=None):
420+
def get_times(
421+
self,
422+
segment_index: int | None = None,
423+
start_frame: int | None = None,
424+
end_frame: int | None = None,
425+
):
421426
"""
422427
Get time vector for a registered recording segment.
423428
@@ -429,7 +434,7 @@ def get_times(self, segment_index=None):
429434
"""
430435
segment_index = self._check_segment_index(segment_index)
431436
if self.has_recording():
432-
return self._recording.get_times(segment_index=segment_index)
437+
return self._recording.get_times(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame)
433438
else:
434439
return None
435440

src/spikeinterface/core/tests/test_baserecording.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def test_BaseRecording(create_cache_folder):
8383
assert values.dtype.kind == "i"
8484

8585
times0 = rec.get_times(segment_index=0)
86+
times0_slice = rec.get_times(segment_index=0, start_frame=10, end_frame=20)
87+
assert np.allclose(times0_slice, times0[10:20])
8688

8789
# dump/load dict
8890
d = rec.to_dict(include_annotations=True, include_properties=True)

src/spikeinterface/core/tests/test_time_handling.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,29 @@ def test_shift_times_with_None_as_t_start():
447447
assert recording.get_start_time() == 1.0
448448

449449

450+
def test_get_times_with_time_vector_slicing():
451+
sampling_frequency = 10_000.0
452+
recording = generate_recording(durations=[1.0], num_channels=3, sampling_frequency=sampling_frequency)
453+
times = 1.0 + np.arange(0, 10_000) / sampling_frequency
454+
recording.set_times(times=times, segment_index=0, with_warning=False)
455+
456+
# Full get_times should return the complete time vector
457+
times_full = recording.get_times(segment_index=0)
458+
assert np.allclose(times_full, times)
459+
460+
# Sliced get_times should match slicing the full vector
461+
times_slice = recording.get_times(segment_index=0, start_frame=1000, end_frame=8000)
462+
assert np.allclose(times_slice, times[1000:8000])
463+
464+
# Only start_frame provided
465+
times_from_start = recording.get_times(segment_index=0, start_frame=5000)
466+
assert np.allclose(times_from_start, times[5000:])
467+
468+
# Only end_frame provided
469+
times_to_end = recording.get_times(segment_index=0, end_frame=3000)
470+
assert np.allclose(times_to_end, times[:3000])
471+
472+
450473
class TestSortingTimeNoRecording:
451474
"""Tests for time methods on BaseSorting without a registered recording."""
452475

src/spikeinterface/core/time_series.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -382,14 +382,20 @@ def __init__(self, sampling_frequency=None, t_start=None, time_vector=None):
382382
BaseSegment.__init__(self)
383383

384384
def get_times(self, start_frame: int | None = None, end_frame: int | None = None) -> np.ndarray:
385-
if start_frame is None:
386-
start_frame = 0
387-
if end_frame is None:
388-
end_frame = self.get_num_samples()
389385
if self.time_vector is not None:
390-
self.time_vector = np.asarray(self.time_vector)
391-
return self.time_vector[start_frame:end_frame]
386+
# Cache full times as numpy if start_frame and end_frame are None. If the user passes start_frame and
387+
# end_frame, we slice the time vector and return the sliced version as numpy array.
388+
# This is useful for very long recordings, where the full time vector might be too large to fit in memory.
389+
if start_frame is None and end_frame is None:
390+
self.time_vector = np.asarray(self.time_vector)
391+
return self.time_vector
392+
else:
393+
start_frame = int(start_frame) if start_frame is not None else 0
394+
end_frame = int(end_frame) if end_frame is not None else self.get_num_samples()
395+
return np.asarray(self.time_vector[start_frame:end_frame])
392396
else:
397+
start_frame = int(start_frame) if start_frame is not None else 0
398+
end_frame = int(end_frame) if end_frame is not None else self.get_num_samples()
393399
time_vector = np.arange(start_frame, end_frame, dtype="float64")
394400
time_vector /= self.sampling_frequency
395401
if self.t_start is not None:

src/spikeinterface/widgets/traces.py

Lines changed: 20 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,6 @@ def __init__(
143143
raise ValueError('You must provide "segment_index" for multisegment recordings.')
144144
segment_index = 0
145145

146-
if not rec0.has_time_vector(segment_index=segment_index):
147-
times = None
148-
else:
149-
times = rec0.get_times(segment_index=segment_index)
150146
t_start = rec0.get_start_time(segment_index=segment_index)
151147
t_end = rec0.get_end_time(segment_index=segment_index)
152148

@@ -172,7 +168,7 @@ def __init__(
172168
cmap = cmap
173169

174170
times_in_range, list_traces, frame_range, channel_ids = _get_trace_list(
175-
recordings, channel_ids, time_range, segment_index, return_in_uV=return_in_uV, times=times
171+
recordings, channel_ids, segment_index, time_range=time_range, return_in_uV=return_in_uV
176172
)
177173

178174
list_traces = [traces * scale for traces in list_traces]
@@ -405,25 +401,12 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
405401
self.figure.canvas.header_visible = False
406402
plt.show()
407403

408-
if not self.rec0.has_time_vector(segment_index=data_plot["segment_index"]):
409-
times = None
410-
t_starts = [
411-
rec0.get_start_time(segment_index=segment_index) for segment_index in range(rec0.get_num_segments())
412-
]
413-
else:
414-
times = [
415-
np.array(self.rec0.get_times(segment_index=segment_index))
416-
for segment_index in range(self.rec0.get_num_segments())
417-
]
418-
t_starts = None
419-
420404
# some widgets
421405
self.time_slider = TimeSlider(
422406
durations=[rec0.get_duration(s) for s in range(rec0.get_num_segments())],
423407
sampling_frequency=rec0.sampling_frequency,
424-
time_range=data_plot["time_range"],
425-
times=times,
426-
t_starts=t_starts,
408+
frame_range=data_plot["frame_range"],
409+
rec0=rec0,
427410
)
428411
# handle times
429412
if data_plot["events"] is not None:
@@ -559,24 +542,17 @@ def _retrieve_traces(self, change=None):
559542

560543
start_frame, end_frame, segment_index = self.time_slider.value
561544

562-
if not self.rec0.has_time_vector(segment_index=segment_index):
563-
times = None
564-
time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency + self.rec0.get_start_time(
565-
segment_index=segment_index
566-
)
567-
else:
568-
times = self.rec0.get_times(segment_index=segment_index)
569-
time_range = np.array([times[start_frame], times[end_frame]])
545+
frame_range = np.array([start_frame, end_frame])
570546

571547
self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()}
572548
times_in_range, list_traces, frame_range, channel_ids = _get_trace_list(
573549
self._selected_recordings,
574550
channel_ids,
575-
time_range,
576551
segment_index,
577552
return_in_uV=self.return_in_uV,
578-
times=times,
553+
frame_range=frame_range,
579554
)
555+
time_range = np.array([times_in_range[0], times_in_range[-1]])
580556

581557
self._channel_ids = channel_ids
582558
self._list_traces = list_traces
@@ -640,12 +616,11 @@ def plot_figpack(self, data_plot, **backend_kwargs):
640616
handle_display_and_url,
641617
import_figpack_or_sortingview,
642618
)
619+
import importlib.util
643620

644621
use_sortingview = backend_kwargs.get("use_sortingview", False)
645622
vv_base, vv_views = import_figpack_or_sortingview(use_sortingview)
646623

647-
import importlib.util
648-
649624
spec = importlib.util.find_spec("pyvips")
650625
if spec is None:
651626
raise ImportError("To use `plot_traces()` in sortingview you need the pyvips package.")
@@ -705,25 +680,28 @@ def plot_ephyviewer(self, data_plot, **backend_kwargs):
705680
app.exec()
706681

707682

708-
def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_in_uV=False, times=None):
683+
def _get_trace_list(recordings, channel_ids, segment_index, time_range=None, return_in_uV=False, frame_range=None):
709684
# function also used in ipywidgets plotter
710685
k0 = list(recordings.keys())[0]
711686
rec0 = recordings[k0]
712687

713-
fs = rec0.get_sampling_frequency()
714-
715688
if return_in_uV:
716689
assert all(
717690
rec.has_scaleable_traces() for rec in recordings.values()
718691
), "Some recording layers do not have scaled traces. Use `return_in_uV=False`"
719-
if times is not None:
720-
frame_range = np.searchsorted(times, time_range)
721-
times = times[frame_range[0] : frame_range[1]]
722-
else:
723-
frame_range = rec0.time_to_sample_index(time_range, segment_index=segment_index)
692+
693+
assert time_range is not None or frame_range is not None, "You must provide either time_range or frame_range"
694+
695+
if frame_range is None:
696+
# use the sampling-frequency approximation to avoid loading the full time vector
697+
t_start = rec0.get_start_time(segment_index=segment_index)
698+
fs = rec0.get_sampling_frequency()
699+
frame_range = np.round((np.asarray(time_range) - t_start) * fs).astype(np.int64)
724700
a_max = rec0.get_num_frames(segment_index=segment_index)
725701
frame_range = np.clip(frame_range, 0, a_max)
726-
times = np.arange(frame_range[0], frame_range[1]) / fs + rec0.get_start_time(segment_index=segment_index)
702+
703+
# lazily load only the needed time slice
704+
times_in_range = rec0.get_times(segment_index=segment_index, start_frame=frame_range[0], end_frame=frame_range[1])
727705

728706
list_traces = []
729707
for rec_name, rec in recordings.items():
@@ -737,4 +715,4 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_i
737715

738716
list_traces.append(traces)
739717

740-
return times, list_traces, frame_range, channel_ids
718+
return times_in_range, list_traces, frame_range, channel_ids

src/spikeinterface/widgets/utils_ipywidgets.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,25 @@ def check_ipywidget_backend():
1616
class TimeSlider(W.HBox):
1717
value = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int())
1818

19-
def __init__(self, durations, sampling_frequency, time_range, times=None, t_starts=None, **kwargs):
19+
def __init__(self, durations, sampling_frequency, frame_range, rec0=None, t_starts=None, **kwargs):
2020
self.num_segments = len(durations)
2121
self.frame_limits = [int(sampling_frequency * d) for d in durations]
2222
self.sampling_frequency = sampling_frequency
2323
self.segment_index = 0
2424

25-
if times is not None:
26-
assert len(times) == len(durations), "times should be a list of arrays with one array per segment"
27-
times_segment = times[self.segment_index]
28-
start_frame, end_frame = np.searchsorted(times_segment, time_range)
29-
self.times = times
25+
start_frame, end_frame = int(frame_range[0]), int(frame_range[1])
26+
27+
if rec0 is not None:
28+
self.rec0 = rec0
3029
self.t_starts = None
3130
else:
3231
assert t_starts is not None
33-
t_start_segment = t_starts[self.segment_index]
34-
start_frame = int((time_range[0] - t_start_segment) * sampling_frequency)
35-
end_frame = int((time_range[1] - t_start_segment) * sampling_frequency)
36-
self.times = None
32+
self.rec0 = None
3733
self.t_starts = t_starts
3834

3935
self.frame_range = (start_frame, end_frame)
4036

41-
self.value = (int(start_frame), int(end_frame), self.segment_index)
37+
self.value = (start_frame, end_frame, self.segment_index)
4238

4339
layout = W.Layout(align_items="center", width="2.5cm", height="1.cm")
4440
but_left = W.Button(description="", disabled=False, button_style="", icon="arrow-left", layout=layout)
@@ -63,8 +59,16 @@ def __init__(self, durations, sampling_frequency, time_range, times=None, t_star
6359
)
6460

6561
# DatetimePicker is only for ipywidget v8 (which is not working in vscode 2023-03)
62+
if self.rec0 is not None:
63+
initial_time = float(
64+
self.rec0.get_times(
65+
segment_index=self.segment_index, start_frame=start_frame, end_frame=start_frame + 1
66+
)[0]
67+
)
68+
else:
69+
initial_time = start_frame / sampling_frequency + self.t_starts[self.segment_index]
6670
self.time_label = W.Text(
67-
value=f"{time_range[0]}", description="", disabled=False, layout=W.Layout(width="2.5cm")
71+
value=f"{initial_time}", description="", disabled=False, layout=W.Layout(width="2.5cm")
6872
)
6973
self.time_label.observe(self.time_label_changed, names="value", type="change")
7074

@@ -137,8 +141,10 @@ def update_time(self, new_frame=None, new_time=None, update_slider=False, update
137141
if new_frame is None and new_time is None:
138142
start_frame = self.slider.value
139143
elif new_frame is None:
140-
if self.times is not None:
141-
start_frame = int(np.searchsorted(self.times[self.segment_index], [new_time])[0])
144+
if self.rec0 is not None:
145+
# approximate via sampling frequency to avoid loading the full time vector
146+
t_start = float(self.rec0.get_start_time(segment_index=self.segment_index))
147+
start_frame = int((new_time - t_start) * self.sampling_frequency)
142148
else:
143149
start_frame = int((new_time - self.t_starts[self.segment_index]) * self.sampling_frequency)
144150
else:
@@ -153,8 +159,12 @@ def update_time(self, new_frame=None, new_time=None, update_slider=False, update
153159

154160
end_frame = min(self.frame_limits[self.segment_index], end_frame)
155161

156-
if self.times is not None:
157-
start_time = self.times[self.segment_index][start_frame]
162+
if self.rec0 is not None:
163+
start_time = float(
164+
self.rec0.get_times(
165+
segment_index=self.segment_index, start_frame=start_frame, end_frame=start_frame + 1
166+
)[0]
167+
)
158168
else:
159169
start_time = start_frame / self.sampling_frequency + self.t_starts[self.segment_index]
160170

0 commit comments

Comments
 (0)