Skip to content

Commit e1006bc

Browse files
committed
Add read_kilosort4_motion function
1 parent 7eb2251 commit e1006bc

2 files changed

Lines changed: 51 additions & 2 deletions

File tree

src/spikeinterface/sorters/external/kilosort4.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import warnings
22
from packaging import version
33

4-
from spikeinterface.core import write_binary_recording
4+
from spikeinterface.core import write_binary_recording, Motion
55
from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs
66
from .kilosortbase import KilosortBase
77
from spikeinterface.sorters.basesorter import get_job_kwargs
@@ -484,3 +484,52 @@ def _setup_json_probe_map(cls, recording, sorter_output_folder):
484484
"n_chan": n_chan,
485485
}
486486
save_probe(probe, str(sorter_output_folder / "chanMap.json"))
487+
488+
# close logger
489+
for handler in logger.handlers.copy():
490+
logger.removeHandler(handler)
491+
handler.close()
492+
493+
494+
def read_kilosort4_motion(sorter_output_folder: str | Path, recording: BaseRecording | None = None) -> Motion:
495+
"""Reads the motion information from a Kilosort4 output folder and returns a Motion object.
496+
497+
Parameters
498+
----------
499+
sorter_output_folder: str or Path
500+
The path to the Kilosort4 output folder.
501+
recording: BaseRecording, optional
502+
The recording object. If provided, the temporal bins will be estimated based on the recording's
503+
start and end times. If not provided, the temporal bins will be estimated based on the number
504+
of batches in the ops file.
505+
506+
Returns
507+
-------
508+
Motion
509+
A Motion object containing the displacement, temporal bins, and spatial bins.
510+
511+
"""
512+
sorter_output_folder = Path(sorter_output_folder)
513+
ops_file = sorter_output_folder / "ops.npy"
514+
if not ops_file.is_file():
515+
raise FileNotFoundError("'ops.npy' file not found!")
516+
ops = np.load(ops_file, allow_pickle=True).item()
517+
yblk = ops.get("yblk")
518+
dshift = ops.get("dshift")
519+
if yblk is None or dshift is None:
520+
raise Exception("'yblk' and 'dshift' fields not found in ops file!")
521+
displacement = dshift + yblk
522+
spatial_bins_um = yblk
523+
# estimate temporal bins
524+
batch_size = ops["batch_size"]
525+
fs = ops["fs"]
526+
t_bin = batch_size / fs
527+
if recording is not None:
528+
t_start = recording.get_start_time()
529+
t_end = recording.get_end_time()
530+
temporal_bins_s = np.linspace(t_start + t_bin / 2, t_end - t_bin / 2)
531+
else:
532+
temporal_bins_s = np.arange(displacement.shape[0]) * t_bin + t_bin / 2
533+
534+
motion = Motion(displacement=displacement, temporal_bins_s=temporal_bins_s, spatial_bins_um=spatial_bins_um)
535+
return motion

src/spikeinterface/sorters/sorterlist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .external.kilosort2 import Kilosort2Sorter
77
from .external.kilosort2_5 import Kilosort2_5Sorter
88
from .external.kilosort3 import Kilosort3Sorter
9-
from .external.kilosort4 import Kilosort4Sorter
9+
from .external.kilosort4 import Kilosort4Sorter, read_kilosort4_motion
1010
from .external.pykilosort import PyKilosortSorter
1111
from .external.klusta import KlustaSorter
1212
from .external.mountainsort4 import Mountainsort4Sorter

0 commit comments

Comments
 (0)