@@ -310,6 +310,37 @@ def test_select_periods():
310310 np .testing .assert_array_equal (sliced_sorting .to_spike_vector (), sliced_sorting_array .to_spike_vector ())
311311
312312
313+ @pytest .mark .parametrize ("use_cache" , [False , True ])
314+ def test_get_unit_spike_trains (use_cache ):
315+ sampling_frequency = 10_000.0
316+ duration = 1.0
317+ num_samples = int (sampling_frequency * duration )
318+ num_units = 10
319+ sorting = generate_sorting (durations = [duration ], sampling_frequency = sampling_frequency , num_units = num_units )
320+
321+ all_spike_trains = sorting .get_unit_spike_trains (unit_ids = sorting .unit_ids , use_cache = use_cache )
322+ assert isinstance (all_spike_trains , dict )
323+ assert set (all_spike_trains .keys ()) == set (sorting .unit_ids )
324+ for unit_id in sorting .unit_ids :
325+ spiketrain = sorting .get_unit_spike_train (segment_index = 0 , unit_id = unit_id , use_cache = use_cache )
326+ assert np .array_equal (all_spike_trains [unit_id ], spiketrain )
327+
328+ # test with times
329+ spike_trains_times = sorting .get_unit_spike_trains_in_seconds (
330+ unit_ids = sorting .unit_ids , return_times = True , use_cache = use_cache
331+ )
332+ assert isinstance (spike_trains_times , dict )
333+ assert set (spike_trains_times .keys ()) == set (sorting .unit_ids )
334+ for unit_id in sorting .unit_ids :
335+ spiketrain = sorting .get_unit_spike_train (
336+ segment_index = 0 , unit_id = unit_id , use_cache = use_cache , return_times = True
337+ )
338+ spiketrain_times = sorting .get_unit_spike_train_in_seconds (
339+ segment_index = 0 , unit_id = unit_id , use_cache = use_cache
340+ )
341+ assert np .allclose (spiketrain_times , spiketrain )
342+
343+
313344if __name__ == "__main__" :
314345 import tempfile
315346
0 commit comments