|
28 | 28 |
|
29 | 29 | import spikeinterface.full as si |
30 | 30 | from spikeinterface.core.testing import check_sortings_equal |
31 | | -from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter |
| 31 | +from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter, read_kilosort4_motion |
| 32 | +from spikeinterface.core.motion import Motion |
32 | 33 | from probeinterface.io import write_prb |
33 | 34 | from spikeinterface.extractors import read_kilosort_as_analyzer |
34 | 35 |
|
@@ -669,6 +670,43 @@ def monkeypatch_filter_function(self, X, ops=None, ibatch=None): |
669 | 670 | assert np.allclose(results["ks"]["st"], results["si"]["st"], rtol=0, atol=1) |
670 | 671 | assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]) |
671 | 672 |
|
| 673 | + def test_read_kilosort4_motion(self, recording_and_paths, tmp_path): |
| 674 | + """ |
| 675 | + Test that read_kilosort4_motion returns a Motion object whose displacement |
| 676 | + equals dshift (not dshift + yblk), and that temporal/spatial bins are correct. |
| 677 | + """ |
| 678 | + recording, _ = recording_and_paths |
| 679 | + sorter_output_dir = tmp_path / "ks4_motion_output" / "sorter_output" |
| 680 | + |
| 681 | + si.run_sorter( |
| 682 | + "kilosort4", |
| 683 | + recording, |
| 684 | + folder=tmp_path / "ks4_motion_output", |
| 685 | + remove_existing_folder=True, |
| 686 | + ) |
| 687 | + |
| 688 | + ops = np.load(sorter_output_dir / "ops.npy", allow_pickle=True).item() |
| 689 | + yblk = ops["yblk"] |
| 690 | + dshift = ops["dshift"] |
| 691 | + |
| 692 | + # without recording: temporal bins estimated from batch count |
| 693 | + motion = read_kilosort4_motion(sorter_output_dir) |
| 694 | + assert isinstance(motion, Motion) |
| 695 | + assert motion.displacement[0].shape == dshift.shape |
| 696 | + np.testing.assert_array_equal(motion.displacement[0], dshift) |
| 697 | + np.testing.assert_array_equal(motion.spatial_bins_um, yblk) |
| 698 | + assert motion.temporal_bins_s[0].shape[0] == dshift.shape[0] |
| 699 | + # displacement must be relative (not offset by spatial bin position) |
| 700 | + assert not np.allclose(motion.displacement[0], dshift + yblk) |
| 701 | + |
| 702 | + # with recording: temporal bins bounded by recording times |
| 703 | + motion_rec = read_kilosort4_motion(sorter_output_dir, recording=recording) |
| 704 | + assert isinstance(motion_rec, Motion) |
| 705 | + np.testing.assert_array_equal(motion_rec.displacement[0], dshift) |
| 706 | + assert motion_rec.temporal_bins_s[0].shape[0] == dshift.shape[0] |
| 707 | + assert motion_rec.temporal_bins_s[0][0] >= recording.get_start_time() |
| 708 | + assert motion_rec.temporal_bins_s[0][-1] <= recording.get_end_time() |
| 709 | + |
672 | 710 | ##### Helpers ###### |
673 | 711 | def _get_kilosort_native_settings(self, recording, paths, param_key, param_value): |
674 | 712 | """ |
|
0 commit comments