Skip to content

Commit a28ea0e

Browse files
committed
conflicts
2 parents 2c50e4c + f976557 commit a28ea0e

448 files changed

Lines changed: 25784 additions & 5355 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/actions/build-test-environment/action.yml

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,7 @@ runs:
1515
shell: bash
1616
run: |
1717
pip install datalad-installer
18-
wget https://downloads.kitenet.net/git-annex/linux/current/git-annex-standalone-amd64.tar.gz
19-
mkdir /home/runner/work/installation
20-
mv git-annex-standalone-amd64.tar.gz /home/runner/work/installation/
21-
workdir=$(pwd)
22-
cd /home/runner/work/installation
23-
tar xvzf git-annex-standalone-amd64.tar.gz
24-
echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH
25-
cd $workdir
18+
datalad-installer --sudo ok git-annex --method datalad/packages
2619
git config --global filter.annex.process "git-annex filter-process" # recommended for efficiency
2720
- name: Force installation of latest dev from key-packages when running dev (not release)
2821
run: |

.github/scripts/check_kilosort4_releases.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import os
2-
import re
32
from pathlib import Path
43
import requests
54
import json
65
from packaging.version import parse
7-
import spikeinterface
6+
87

98
def get_pypi_versions(package_name):
109
"""
@@ -16,8 +15,10 @@ def get_pypi_versions(package_name):
1615
response.raise_for_status()
1716
data = response.json()
1817
versions = list(sorted(data["releases"].keys()))
19-
# Filter out versions that are less than 4.0.16
20-
versions = [ver for ver in versions if parse(ver) >= parse("4.0.16")]
18+
# Filter out versions that are less than 4.0.16 and different from 4.0.26 and 4.0.27
19+
# (buggy - https://github.com/MouseLand/Kilosort/releases/tag/v4.0.26)
20+
versions = [ver for ver in versions if parse(ver) >= parse("4.0.16") and
21+
parse(ver) not in [parse("4.0.26"), parse("4.0.27")]]
2122
return versions
2223

2324

.github/scripts/determine_testing_environment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
elif changed_file.name == "nwbextractors.py":
4949
extractors_changed = True # There are NWB tests that are not streaming
5050
stream_extractors_changed = True
51-
elif changed_file.name == "iblextractors.py":
51+
elif changed_file.name == "iblextractors.py" or changed_file.name == "test_iblextractors.py":
5252
stream_extractors_changed = True
5353
elif "core" in changed_file.parts:
5454
core_changed = True

.github/scripts/test_kilosort4_ci.py

Lines changed: 103 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import pytest
2222
import copy
23-
from typing import Any
23+
from packaging.version import parse
2424
from inspect import signature
2525

2626
import numpy as np
@@ -70,7 +70,6 @@
7070
"nearest_chans": 8,
7171
"nearest_templates": 35,
7272
"max_channel_distance": 5,
73-
"templates_from_data": False,
7473
"n_templates": 10,
7574
"n_pcs": 3,
7675
"Th_single_ch": 4,
@@ -84,8 +83,6 @@
8483
"duplicate_spike_ms": 0.3,
8584
}
8685

87-
PARAMS_TO_TEST = list(PARAMS_TO_TEST_DICT.keys())
88-
8986
PARAMETERS_NOT_AFFECTING_RESULTS = [
9087
"artifact_threshold",
9188
"ccg_threshold",
@@ -95,13 +92,32 @@
9592
"duplicate_spike_ms", # this is because ground-truth spikes don't have violations
9693
]
9794

98-
# THIS IS A PLACEHOLDER FOR FUTURE PARAMS TO TEST
99-
# if parse(version("kilosort")) >= parse("4.0.X"):
100-
# PARAMS_TO_TEST_DICT.update(
101-
# [
102-
# {"new_param": new_value},
103-
# ]
104-
# )
95+
96+
# Add/Remove version specific parameters
97+
if parse(kilosort.__version__) >= parse("4.0.22"):
98+
PARAMS_TO_TEST_DICT.update(
99+
{"position_limit": 50}
100+
)
101+
# Position limit only affects computing spike locations after sorting
102+
PARAMETERS_NOT_AFFECTING_RESULTS.append("position_limit")
103+
104+
if parse(kilosort.__version__) >= parse("4.0.24"):
105+
PARAMS_TO_TEST_DICT.update(
106+
{"max_peels": 200},
107+
)
108+
# max_peels is not affecting the results in this short dataset
109+
PARAMETERS_NOT_AFFECTING_RESULTS.append("max_peels")
110+
111+
if parse(kilosort.__version__) >= parse("4.0.33"):
112+
PARAMS_TO_TEST_DICT.update({"cluster_neighbors": 11})
113+
PARAMETERS_NOT_AFFECTING_RESULTS.append("cluster_neighbors")
114+
115+
if parse(kilosort.__version__) >= parse("4.0.37"):
116+
PARAMS_TO_TEST_DICT.update({"max_cluster_subset": 20})
117+
PARAMETERS_NOT_AFFECTING_RESULTS.append("max_cluster_subset")
118+
119+
120+
PARAMS_TO_TEST = list(PARAMS_TO_TEST_DICT.keys())
105121

106122

107123
class TestKilosort4Long:
@@ -169,11 +185,11 @@ def _save_ground_truth_recording(self, recording, tmp_path):
169185
"""
170186
paths = {
171187
"session_scope_tmp_path": tmp_path,
172-
"recording_path": tmp_path / "my_test_recording",
188+
"recording_path": tmp_path / "my_test_recording" / "traces_cached_seg0.raw",
173189
"probe_path": tmp_path / "my_test_probe.prb",
174190
}
175191

176-
recording.save(folder=paths["recording_path"], overwrite=True)
192+
recording.save(folder=paths["recording_path"].parent, overwrite=True)
177193

178194
probegroup = recording.get_probegroup()
179195
write_prb(paths["probe_path"].as_posix(), probegroup)
@@ -205,7 +221,7 @@ def test_default_settings_all_represented(self):
205221
tested_keys += additional_non_tested_keys
206222

207223
for param_key in DEFAULT_SETTINGS:
208-
if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]:
224+
if param_key not in ["n_chan_bin", "fs", "tmin", "tmax", "templates_from_data"]:
209225
assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested."
210226

211227
def test_spikeinterface_defaults_against_kilsort(self):
@@ -225,8 +241,11 @@ def test_spikeinterface_defaults_against_kilsort(self):
225241

226242
# Testing Arguments ###
227243
def test_set_files_arguments(self):
244+
expected_arguments = ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"]
245+
if parse(kilosort.__version__) >= parse("4.0.34"):
246+
expected_arguments += ["shank_idx"]
228247
self._check_arguments(
229-
set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"]
248+
set_files, expected_arguments
230249
)
231250

232251
def test_initialize_ops_arguments(self):
@@ -239,6 +258,8 @@ def test_initialize_ops_arguments(self):
239258
"device",
240259
"save_preprocessed_copy",
241260
]
261+
if parse(kilosort.__version__) >= parse("4.0.37"):
262+
expected_arguments += ["gui_mode"]
242263

243264
self._check_arguments(
244265
initialize_ops,
@@ -249,17 +270,22 @@ def test_compute_preprocessing_arguments(self):
249270
self._check_arguments(compute_preprocessing, ["ops", "device", "tic0", "file_object"])
250271

251272
def test_compute_drift_location_arguments(self):
252-
self._check_arguments(
253-
compute_drift_correction, ["ops", "device", "tic0", "progress_bar", "file_object", "clear_cache"]
254-
)
273+
expected_arguments = ["ops", "device", "tic0", "progress_bar", "file_object", "clear_cache"]
274+
if parse(kilosort.__version__) >= parse("4.0.28"):
275+
expected_arguments += ["verbose"]
276+
self._check_arguments(compute_drift_correction, expected_arguments)
255277

256278
def test_detect_spikes_arguments(self):
257-
self._check_arguments(detect_spikes, ["ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"])
279+
expected_arguments = ["ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"]
280+
if parse(kilosort.__version__) >= parse("4.0.28"):
281+
expected_arguments += ["verbose"]
282+
self._check_arguments(detect_spikes, expected_arguments)
258283

259284
def test_cluster_spikes_arguments(self):
260-
self._check_arguments(
261-
cluster_spikes, ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"]
262-
)
285+
expected_arguments = ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"]
286+
if parse(kilosort.__version__) >= parse("4.0.28"):
287+
expected_arguments += ["verbose"]
288+
self._check_arguments(cluster_spikes, expected_arguments)
263289

264290
def test_save_sorting_arguments(self):
265291
expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"]
@@ -519,33 +545,60 @@ def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, pa
519545
kilosort_output_dir = tmp_path / "kilosort_output_dir"
520546
spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir"
521547

522-
def monkeypatch_filter_function(self, X, ops=None, ibatch=None):
523-
"""
524-
This is a direct copy of the kilosort io.BinaryFiltered.filter
525-
function, with hp_filter and whitening matrix code sections, and
526-
comments removed. This is the easiest way to monkeypatch (tried a few approaches)
527-
"""
528-
if self.chan_map is not None:
529-
X = X[self.chan_map]
530-
531-
if self.invert_sign:
532-
X = X * -1
533-
534-
X = X - X.mean(1).unsqueeze(1)
535-
if self.do_CAR:
536-
X = X - torch.median(X, 0)[0]
537-
538-
if self.hp_filter is not None:
539-
pass
540-
541-
if self.artifact_threshold < np.inf:
542-
if torch.any(torch.abs(X) >= self.artifact_threshold):
543-
return torch.zeros_like(X)
544-
545-
if self.whiten_mat is not None:
546-
pass
547-
return X
548-
548+
if parse(kilosort.__version__) >= parse("4.0.33"):
549+
def monkeypatch_filter_function(self, X, ops=None, ibatch=None, skip_preproc=False):
550+
"""
551+
This is a direct copy of the kilosort io.BinaryFiltered.filter
552+
function, with hp_filter and whitening matrix code sections, and
553+
comments removed. This is the easiest way to monkeypatch (tried a few approaches)
554+
"""
555+
if self.chan_map is not None:
556+
X = X[self.chan_map]
557+
558+
if self.invert_sign:
559+
X = X * -1
560+
561+
X = X - X.mean(1).unsqueeze(1)
562+
if self.do_CAR:
563+
X = X - torch.median(X, 0)[0]
564+
565+
if self.hp_filter is not None:
566+
pass
567+
568+
if self.artifact_threshold < np.inf:
569+
if torch.any(torch.abs(X) >= self.artifact_threshold):
570+
return torch.zeros_like(X)
571+
572+
if self.whiten_mat is not None:
573+
pass
574+
return X
575+
else:
576+
def monkeypatch_filter_function(self, X, ops=None, ibatch=None):
577+
"""
578+
This is a direct copy of the kilosort io.BinaryFiltered.filter
579+
function, with hp_filter and whitening matrix code sections, and
580+
comments removed. This is the easiest way to monkeypatch (tried a few approaches)
581+
"""
582+
if self.chan_map is not None:
583+
X = X[self.chan_map]
584+
585+
if self.invert_sign:
586+
X = X * -1
587+
588+
X = X - X.mean(1).unsqueeze(1)
589+
if self.do_CAR:
590+
X = X - torch.median(X, 0)[0]
591+
592+
if self.hp_filter is not None:
593+
pass
594+
595+
if self.artifact_threshold < np.inf:
596+
if torch.any(torch.abs(X) >= self.artifact_threshold):
597+
return torch.zeros_like(X)
598+
599+
if self.whiten_mat is not None:
600+
pass
601+
return X
549602
monkeypatch.setattr("kilosort.io.BinaryFiltered.filter", monkeypatch_filter_function)
550603

551604
ks_settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value)
@@ -606,7 +659,7 @@ def _get_kilosort_native_settings(self, recording, paths, param_key, param_value
606659
are through the function, these are split here.
607660
"""
608661
settings = {
609-
"data_dir": paths["recording_path"],
662+
"filename": paths["recording_path"],
610663
"n_chan_bin": recording.get_num_channels(),
611664
"fs": recording.get_sampling_frequency(),
612665
}

0 commit comments

Comments
 (0)