Skip to content

Commit 1b8aa4b

Browse files
committed
Solve conflicts
2 parents c252b9c + c398228 commit 1b8aa4b

460 files changed

Lines changed: 24488 additions & 6072 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/actions/build-test-environment/action.yml

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,7 @@ runs:
1515
shell: bash
1616
run: |
1717
pip install datalad-installer
18-
wget https://downloads.kitenet.net/git-annex/linux/current/git-annex-standalone-amd64.tar.gz
19-
mkdir /home/runner/work/installation
20-
mv git-annex-standalone-amd64.tar.gz /home/runner/work/installation/
21-
workdir=$(pwd)
22-
cd /home/runner/work/installation
23-
tar xvzf git-annex-standalone-amd64.tar.gz
24-
echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH
25-
cd $workdir
18+
datalad-installer --sudo ok git-annex --method datalad/packages
2619
git config --global filter.annex.process "git-annex filter-process" # recommended for efficiency
2720
- name: Force installation of latest dev from key-packages when running dev (not release)
2821
run: |

.github/scripts/check_kilosort4_releases.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import os
2-
import re
32
from pathlib import Path
43
import requests
54
import json
65
from packaging.version import parse
7-
import spikeinterface
6+
87

98
def get_pypi_versions(package_name):
109
"""
@@ -16,8 +15,10 @@ def get_pypi_versions(package_name):
1615
response.raise_for_status()
1716
data = response.json()
1817
versions = list(sorted(data["releases"].keys()))
19-
# Filter out versions that are less than 4.0.16
20-
versions = [ver for ver in versions if parse(ver) >= parse("4.0.16")]
18+
# Filter out versions that are less than 4.0.16 and different from 4.0.26 and 4.0.27
19+
# (buggy - https://github.com/MouseLand/Kilosort/releases/tag/v4.0.26)
20+
versions = [ver for ver in versions if parse(ver) >= parse("4.0.16") and
21+
parse(ver) not in [parse("4.0.26"), parse("4.0.27")]]
2122
return versions
2223

2324

.github/scripts/determine_testing_environment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
elif changed_file.name == "nwbextractors.py":
4949
extractors_changed = True # There are NWB tests that are not streaming
5050
stream_extractors_changed = True
51-
elif changed_file.name == "iblextractors.py":
51+
elif changed_file.name == "iblextractors.py" or changed_file.name == "test_iblextractors.py":
5252
stream_extractors_changed = True
5353
elif "core" in changed_file.parts:
5454
core_changed = True

.github/scripts/test_kilosort4_ci.py

Lines changed: 89 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@
7070
"nearest_chans": 8,
7171
"nearest_templates": 35,
7272
"max_channel_distance": 5,
73-
"templates_from_data": False,
7473
"n_templates": 10,
7574
"n_pcs": 3,
7675
"Th_single_ch": 4,
@@ -109,6 +108,14 @@
109108
# max_peels is not affecting the results in this short dataset
110109
PARAMETERS_NOT_AFFECTING_RESULTS.append("max_peels")
111110

111+
if parse(kilosort.__version__) >= parse("4.0.33"):
112+
PARAMS_TO_TEST_DICT.update({"cluster_neighbors": 11})
113+
PARAMETERS_NOT_AFFECTING_RESULTS.append("cluster_neighbors")
114+
115+
if parse(kilosort.__version__) >= parse("4.0.37"):
116+
PARAMS_TO_TEST_DICT.update({"max_cluster_subset": 20})
117+
PARAMETERS_NOT_AFFECTING_RESULTS.append("max_cluster_subset")
118+
112119

113120
PARAMS_TO_TEST = list(PARAMS_TO_TEST_DICT.keys())
114121

@@ -178,11 +185,11 @@ def _save_ground_truth_recording(self, recording, tmp_path):
178185
"""
179186
paths = {
180187
"session_scope_tmp_path": tmp_path,
181-
"recording_path": tmp_path / "my_test_recording",
188+
"recording_path": tmp_path / "my_test_recording" / "traces_cached_seg0.raw",
182189
"probe_path": tmp_path / "my_test_probe.prb",
183190
}
184191

185-
recording.save(folder=paths["recording_path"], overwrite=True)
192+
recording.save(folder=paths["recording_path"].parent, overwrite=True)
186193

187194
probegroup = recording.get_probegroup()
188195
write_prb(paths["probe_path"].as_posix(), probegroup)
@@ -214,7 +221,7 @@ def test_default_settings_all_represented(self):
214221
tested_keys += additional_non_tested_keys
215222

216223
for param_key in DEFAULT_SETTINGS:
217-
if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]:
224+
if param_key not in ["n_chan_bin", "fs", "tmin", "tmax", "templates_from_data"]:
218225
assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested."
219226

220227
def test_spikeinterface_defaults_against_kilsort(self):
@@ -234,8 +241,11 @@ def test_spikeinterface_defaults_against_kilsort(self):
234241

235242
# Testing Arguments ###
236243
def test_set_files_arguments(self):
244+
expected_arguments = ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"]
245+
if parse(kilosort.__version__) >= parse("4.0.34"):
246+
expected_arguments += ["shank_idx"]
237247
self._check_arguments(
238-
set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"]
248+
set_files, expected_arguments
239249
)
240250

241251
def test_initialize_ops_arguments(self):
@@ -248,6 +258,8 @@ def test_initialize_ops_arguments(self):
248258
"device",
249259
"save_preprocessed_copy",
250260
]
261+
if parse(kilosort.__version__) >= parse("4.0.37"):
262+
expected_arguments += ["gui_mode"]
251263

252264
self._check_arguments(
253265
initialize_ops,
@@ -258,22 +270,29 @@ def test_compute_preprocessing_arguments(self):
258270
self._check_arguments(compute_preprocessing, ["ops", "device", "tic0", "file_object"])
259271

260272
def test_compute_drift_location_arguments(self):
261-
self._check_arguments(
262-
compute_drift_correction, ["ops", "device", "tic0", "progress_bar", "file_object", "clear_cache"]
263-
)
273+
expected_arguments = ["ops", "device", "tic0", "progress_bar", "file_object", "clear_cache"]
274+
if parse(kilosort.__version__) >= parse("4.0.28"):
275+
expected_arguments += ["verbose"]
276+
self._check_arguments(compute_drift_correction, expected_arguments)
264277

265278
def test_detect_spikes_arguments(self):
266-
self._check_arguments(detect_spikes, ["ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"])
279+
expected_arguments = ["ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"]
280+
if parse(kilosort.__version__) >= parse("4.0.28"):
281+
expected_arguments += ["verbose"]
282+
self._check_arguments(detect_spikes, expected_arguments)
267283

268284
def test_cluster_spikes_arguments(self):
269-
self._check_arguments(
270-
cluster_spikes, ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"]
271-
)
285+
expected_arguments = ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"]
286+
if parse(kilosort.__version__) >= parse("4.0.28"):
287+
expected_arguments += ["verbose"]
288+
self._check_arguments(cluster_spikes, expected_arguments)
272289

273290
def test_save_sorting_arguments(self):
274-
expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"]
275-
276-
expected_arguments.append("save_preprocessed_copy")
291+
expected_arguments = [
292+
"ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars", "save_preprocessed_copy"
293+
]
294+
if parse(kilosort.__version__) >= parse("4.0.39"):
295+
expected_arguments.append("skip_dat_path")
277296

278297
self._check_arguments(save_sorting, expected_arguments)
279298

@@ -528,33 +547,60 @@ def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, pa
528547
kilosort_output_dir = tmp_path / "kilosort_output_dir"
529548
spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir"
530549

531-
def monkeypatch_filter_function(self, X, ops=None, ibatch=None):
532-
"""
533-
This is a direct copy of the kilosort io.BinaryFiltered.filter
534-
function, with hp_filter and whitening matrix code sections, and
535-
comments removed. This is the easiest way to monkeypatch (tried a few approaches)
536-
"""
537-
if self.chan_map is not None:
538-
X = X[self.chan_map]
539-
540-
if self.invert_sign:
541-
X = X * -1
542-
543-
X = X - X.mean(1).unsqueeze(1)
544-
if self.do_CAR:
545-
X = X - torch.median(X, 0)[0]
546-
547-
if self.hp_filter is not None:
548-
pass
549-
550-
if self.artifact_threshold < np.inf:
551-
if torch.any(torch.abs(X) >= self.artifact_threshold):
552-
return torch.zeros_like(X)
553-
554-
if self.whiten_mat is not None:
555-
pass
556-
return X
557-
550+
if parse(kilosort.__version__) >= parse("4.0.33"):
551+
def monkeypatch_filter_function(self, X, ops=None, ibatch=None, skip_preproc=False):
552+
"""
553+
This is a direct copy of the kilosort io.BinaryFiltered.filter
554+
function, with hp_filter and whitening matrix code sections, and
555+
comments removed. This is the easiest way to monkeypatch (tried a few approaches)
556+
"""
557+
if self.chan_map is not None:
558+
X = X[self.chan_map]
559+
560+
if self.invert_sign:
561+
X = X * -1
562+
563+
X = X - X.mean(1).unsqueeze(1)
564+
if self.do_CAR:
565+
X = X - torch.median(X, 0)[0]
566+
567+
if self.hp_filter is not None:
568+
pass
569+
570+
if self.artifact_threshold < np.inf:
571+
if torch.any(torch.abs(X) >= self.artifact_threshold):
572+
return torch.zeros_like(X)
573+
574+
if self.whiten_mat is not None:
575+
pass
576+
return X
577+
else:
578+
def monkeypatch_filter_function(self, X, ops=None, ibatch=None):
579+
"""
580+
This is a direct copy of the kilosort io.BinaryFiltered.filter
581+
function, with hp_filter and whitening matrix code sections, and
582+
comments removed. This is the easiest way to monkeypatch (tried a few approaches)
583+
"""
584+
if self.chan_map is not None:
585+
X = X[self.chan_map]
586+
587+
if self.invert_sign:
588+
X = X * -1
589+
590+
X = X - X.mean(1).unsqueeze(1)
591+
if self.do_CAR:
592+
X = X - torch.median(X, 0)[0]
593+
594+
if self.hp_filter is not None:
595+
pass
596+
597+
if self.artifact_threshold < np.inf:
598+
if torch.any(torch.abs(X) >= self.artifact_threshold):
599+
return torch.zeros_like(X)
600+
601+
if self.whiten_mat is not None:
602+
pass
603+
return X
558604
monkeypatch.setattr("kilosort.io.BinaryFiltered.filter", monkeypatch_filter_function)
559605

560606
ks_settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value)
@@ -615,7 +661,7 @@ def _get_kilosort_native_settings(self, recording, paths, param_key, param_value
615661
are through the function, these are split here.
616662
"""
617663
settings = {
618-
"data_dir": paths["recording_path"],
664+
"filename": paths["recording_path"],
619665
"n_chan_bin": recording.get_num_channels(),
620666
"fs": recording.get_sampling_frequency(),
621667
}

0 commit comments

Comments
 (0)