|
1 | 1 | import warnings |
2 | 2 | from packaging import version |
3 | 3 |
|
4 | | -from spikeinterface.core import write_binary_recording |
| 4 | +from spikeinterface.core import write_binary_recording, Motion |
5 | 5 | from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs |
6 | 6 | from .kilosortbase import KilosortBase |
7 | 7 | from spikeinterface.sorters.basesorter import get_job_kwargs |
@@ -484,3 +484,52 @@ def _setup_json_probe_map(cls, recording, sorter_output_folder): |
484 | 484 | "n_chan": n_chan, |
485 | 485 | } |
486 | 486 | save_probe(probe, str(sorter_output_folder / "chanMap.json")) |
| 487 | + |
| 488 | + # close logger |
| 489 | + for handler in logger.handlers.copy(): |
| 490 | + logger.removeHandler(handler) |
| 491 | + handler.close() |
| 492 | + |
| 493 | + |
| 494 | +def read_kilosort4_motion(sorter_output_folder: str | Path, recording: BaseRecording | None = None) -> Motion: |
| 495 | + """Reads the motion information from a Kilosort4 output folder and returns a Motion object. |
| 496 | +
|
| 497 | + Parameters |
| 498 | + ---------- |
| 499 | + sorter_output_folder: str or Path |
| 500 | + The path to the Kilosort4 output folder. |
| 501 | + recording: BaseRecording, optional |
| 502 | + The recording object. If provided, the temporal bins will be estimated based on the recording's |
| 503 | + start and end times. If not provided, the temporal bins will be estimated based on the number |
| 504 | + of batches in the ops file. |
| 505 | +
|
| 506 | + Returns |
| 507 | + ------- |
| 508 | + Motion |
| 509 | + A Motion object containing the displacement, temporal bins, and spatial bins. |
| 510 | +
|
| 511 | + """ |
| 512 | + sorter_output_folder = Path(sorter_output_folder) |
| 513 | + ops_file = sorter_output_folder / "ops.npy" |
| 514 | + if not ops_file.is_file(): |
| 515 | + raise FileNotFoundError("'ops.npy' file not found!") |
| 516 | + ops = np.load(ops_file, allow_pickle=True).item() |
| 517 | + yblk = ops.get("yblk") |
| 518 | + dshift = ops.get("dshift") |
| 519 | + if yblk is None or dshift is None: |
| 520 | + raise Exception("'yblk' and 'dshift' fields not found in ops file!") |
| 521 | + displacement = dshift + yblk |
| 522 | + spatial_bins_um = yblk |
| 523 | + # estimate temporal bins |
| 524 | + batch_size = ops["batch_size"] |
| 525 | + fs = ops["fs"] |
| 526 | + t_bin = batch_size / fs |
| 527 | + if recording is not None: |
| 528 | + t_start = recording.get_start_time() |
| 529 | + t_end = recording.get_end_time() |
| 530 | + temporal_bins_s = np.linspace(t_start + t_bin / 2, t_end - t_bin / 2) |
| 531 | + else: |
| 532 | + temporal_bins_s = np.arange(displacement.shape[0]) * t_bin + t_bin / 2 |
| 533 | + |
| 534 | + motion = Motion(displacement=displacement, temporal_bins_s=temporal_bins_s, spatial_bins_um=spatial_bins_um) |
| 535 | + return motion |
0 commit comments