Skip to content

Commit 0ea163f

Browse files
committed
fix: read function displacement, add test for read_kilosort4_motion function
1 parent b26ab8f commit 0ea163f

3 files changed

Lines changed: 45 additions & 6 deletions

File tree

.github/scripts/test_kilosort4_ci.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828

2929
import spikeinterface.full as si
3030
from spikeinterface.core.testing import check_sortings_equal
31-
from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter
31+
from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter, read_kilosort4_motion
32+
from spikeinterface.core.motion import Motion
3233
from probeinterface.io import write_prb
3334
from spikeinterface.extractors import read_kilosort_as_analyzer
3435

@@ -669,6 +670,43 @@ def monkeypatch_filter_function(self, X, ops=None, ibatch=None):
669670
assert np.allclose(results["ks"]["st"], results["si"]["st"], rtol=0, atol=1)
670671
assert np.array_equal(results["ks"]["clus"], results["si"]["clus"])
671672

673+
def test_read_kilosort4_motion(self, recording_and_paths, tmp_path):
674+
"""
675+
Test that read_kilosort4_motion returns a Motion object whose displacement
676+
equals dshift (not dshift + yblk), and that temporal/spatial bins are correct.
677+
"""
678+
recording, _ = recording_and_paths
679+
sorter_output_dir = tmp_path / "ks4_motion_output" / "sorter_output"
680+
681+
si.run_sorter(
682+
"kilosort4",
683+
recording,
684+
folder=tmp_path / "ks4_motion_output",
685+
remove_existing_folder=True,
686+
)
687+
688+
ops = np.load(sorter_output_dir / "ops.npy", allow_pickle=True).item()
689+
yblk = ops["yblk"]
690+
dshift = ops["dshift"]
691+
692+
# without recording: temporal bins estimated from batch count
693+
motion = read_kilosort4_motion(sorter_output_dir)
694+
assert isinstance(motion, Motion)
695+
assert motion.displacement[0].shape == dshift.shape
696+
np.testing.assert_array_equal(motion.displacement[0], dshift)
697+
np.testing.assert_array_equal(motion.spatial_bins_um, yblk)
698+
assert motion.temporal_bins_s[0].shape[0] == dshift.shape[0]
699+
# displacement must be relative (not offset by spatial bin position)
700+
assert not np.allclose(motion.displacement[0], dshift + yblk)
701+
702+
# with recording: temporal bins bounded by recording times
703+
motion_rec = read_kilosort4_motion(sorter_output_dir, recording=recording)
704+
assert isinstance(motion_rec, Motion)
705+
np.testing.assert_array_equal(motion_rec.displacement[0], dshift)
706+
assert motion_rec.temporal_bins_s[0].shape[0] == dshift.shape[0]
707+
assert motion_rec.temporal_bins_s[0][0] >= recording.get_start_time()
708+
assert motion_rec.temporal_bins_s[0][-1] <= recording.get_end_time()
709+
672710
##### Helpers ######
673711
def _get_kilosort_native_settings(self, recording, paths, param_key, param_value):
674712
"""

src/spikeinterface/core/motion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ class Motion:
1414
Parameters
1515
----------
1616
displacement : numpy array 2d or list of
17-
Motion estimate in um.
17+
Motion estimate in um, relative to the spatial_bins_um.
18+
The first dimension is temporal bins, the second dimension is spatial bins.
1819
List is the number of segment.
1920
For each semgent :
2021
@@ -93,6 +94,7 @@ def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_inde
9394
Parameters
9495
----------
9596
times_s: np.array
97+
Times at which to evaluate the motion, in seconds. This should be a one-dimensional array.
9698
locations_um: np.array
9799
Either this is a one-dimensional array (a vector of positions along self.dimension), or
98100
else a 2d array with the 2 or 3 spatial dimensions indexed along axis=1.

src/spikeinterface/sorters/external/kilosort4.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pathlib import Path
33
from packaging import version
44

5+
import numpy as np
56

67
from spikeinterface.core import write_binary_recording, Motion, BaseRecording
78
from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs
@@ -171,7 +172,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
171172

172173
import time
173174
import torch
174-
import numpy as np
175175
import logging
176176

177177
if version.parse(cls.get_sorter_version()) < version.parse("4.0.16"):
@@ -468,7 +468,6 @@ def _get_result_from_folder(cls, sorter_output_folder):
468468
def _setup_json_probe_map(cls, recording, sorter_output_folder):
469469
"""Create a JSON probe map file for Kilosort4."""
470470
from kilosort.io import save_probe
471-
import numpy as np
472471

473472
groups = recording.get_channel_groups()
474473
positions = np.array(recording.get_channel_locations())
@@ -520,7 +519,7 @@ def read_kilosort4_motion(sorter_output_folder: str | Path, recording: BaseRecor
520519
dshift = ops.get("dshift")
521520
if yblk is None or dshift is None:
522521
raise Exception("'yblk' and 'dshift' fields not found in ops file!")
523-
displacement = dshift + yblk
522+
displacement = dshift
524523
spatial_bins_um = yblk
525524
# estimate temporal bins
526525
batch_size = ops["batch_size"]
@@ -529,7 +528,7 @@ def read_kilosort4_motion(sorter_output_folder: str | Path, recording: BaseRecor
529528
if recording is not None:
530529
t_start = recording.get_start_time()
531530
t_end = recording.get_end_time()
532-
temporal_bins_s = np.linspace(t_start + t_bin / 2, t_end - t_bin / 2)
531+
temporal_bins_s = np.linspace(t_start + t_bin / 2, t_end - t_bin / 2, displacement.shape[0])
533532
else:
534533
temporal_bins_s = np.arange(displacement.shape[0]) * t_bin + t_bin / 2
535534

0 commit comments

Comments
 (0)