Skip to content

Commit b1911bf

Browse files
committed
add tests and fixes
1 parent 0efad83 commit b1911bf

2 files changed

Lines changed: 37 additions & 4 deletions

File tree

src/spikeinterface/core/basesorting.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def get_unit_spike_trains(
350350
)
351351

352352
segment_index = self._check_segment_index(segment_index)
353+
segment = self.segments[segment_index]
353354
if use_cache:
354355
# TODO: speed things up
355356
ordered_spike_vector, slices = self.to_reordered_spike_vector(
@@ -372,8 +373,9 @@ def get_unit_spike_trains(
372373
spike_trains[unit_id] = spike_frames
373374
else:
374375
spike_trains = segment.get_unit_spike_trains(
375-
unit_ids=unit_ids, start_frame=start_frame, end_frame=end_frame, return_times=return_times
376+
unit_ids=unit_ids, start_frame=start_frame, end_frame=end_frame
376377
)
378+
return spike_trains
377379

378380
def get_unit_spike_trains_in_seconds(
379381
self,
@@ -453,10 +455,10 @@ def get_unit_spike_trains_in_seconds(
453455
use_cache=use_cache,
454456
)
455457
for unit_id in unit_ids:
456-
spike_frames = spike_frames[unit_id]
458+
spike_frames_unit = spike_frames[unit_id]
457459
t_start = segment._t_start if segment._t_start is not None else 0
458-
spike_times[unit_id] = spike_frames / self.get_sampling_frequency()
459-
return t_start + spike_times
460+
spike_times[unit_id] = spike_frames_unit / self.get_sampling_frequency() + t_start
461+
return spike_times
460462

461463
def register_recording(self, recording, check_spike_frames: bool = True):
462464
"""

src/spikeinterface/core/tests/test_basesorting.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,37 @@ def test_select_periods():
310310
np.testing.assert_array_equal(sliced_sorting.to_spike_vector(), sliced_sorting_array.to_spike_vector())
311311

312312

313+
@pytest.mark.parametrize("use_cache", [False, True])
314+
def test_get_unit_spike_trains(use_cache):
315+
sampling_frequency = 10_000.0
316+
duration = 1.0
317+
num_samples = int(sampling_frequency * duration)
318+
num_units = 10
319+
sorting = generate_sorting(durations=[duration], sampling_frequency=sampling_frequency, num_units=num_units)
320+
321+
all_spike_trains = sorting.get_unit_spike_trains(unit_ids=sorting.unit_ids, use_cache=use_cache)
322+
assert isinstance(all_spike_trains, dict)
323+
assert set(all_spike_trains.keys()) == set(sorting.unit_ids)
324+
for unit_id in sorting.unit_ids:
325+
spiketrain = sorting.get_unit_spike_train(segment_index=0, unit_id=unit_id, use_cache=use_cache)
326+
assert np.array_equal(all_spike_trains[unit_id], spiketrain)
327+
328+
# test with times
329+
spike_trains_times = sorting.get_unit_spike_trains_in_seconds(
330+
unit_ids=sorting.unit_ids, return_times=True, use_cache=use_cache
331+
)
332+
assert isinstance(spike_trains_times, dict)
333+
assert set(spike_trains_times.keys()) == set(sorting.unit_ids)
334+
for unit_id in sorting.unit_ids:
335+
spiketrain = sorting.get_unit_spike_train(
336+
segment_index=0, unit_id=unit_id, use_cache=use_cache, return_times=True
337+
)
338+
spiketrain_times = sorting.get_unit_spike_train_in_seconds(
339+
segment_index=0, unit_id=unit_id, use_cache=use_cache
340+
)
341+
assert np.allclose(spiketrain_times, spiketrain)
342+
343+
313344
if __name__ == "__main__":
314345
import tempfile
315346

0 commit comments

Comments
 (0)