Skip to content

Commit 4e90730

Browse files
authored
Merge pull request #3900 from alejoe91/new-ks-versions
Fix ks4 tests and support ks4>=4.0.34
2 parents d142fe4 + eb951e0 commit 4e90730

2 files changed

Lines changed: 79 additions & 37 deletions

File tree

.github/scripts/test_kilosort4_ci.py

Lines changed: 66 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@
7070
"nearest_chans": 8,
7171
"nearest_templates": 35,
7272
"max_channel_distance": 5,
73-
"templates_from_data": False,
7473
"n_templates": 10,
7574
"n_pcs": 3,
7675
"Th_single_ch": 4,
@@ -109,6 +108,10 @@
109108
# max_peels is not affecting the results in this short dataset
110109
PARAMETERS_NOT_AFFECTING_RESULTS.append("max_peels")
111110

111+
if parse(kilosort.__version__) >= parse("4.0.33"):
112+
PARAMS_TO_TEST_DICT.update({"cluster_neighbors": 11})
113+
PARAMETERS_NOT_AFFECTING_RESULTS.append("cluster_neighbors")
114+
112115

113116
PARAMS_TO_TEST = list(PARAMS_TO_TEST_DICT.keys())
114117

@@ -178,11 +181,11 @@ def _save_ground_truth_recording(self, recording, tmp_path):
178181
"""
179182
paths = {
180183
"session_scope_tmp_path": tmp_path,
181-
"recording_path": tmp_path / "my_test_recording",
184+
"recording_path": tmp_path / "my_test_recording" / "traces_cached_seg0.raw",
182185
"probe_path": tmp_path / "my_test_probe.prb",
183186
}
184187

185-
recording.save(folder=paths["recording_path"], overwrite=True)
188+
recording.save(folder=paths["recording_path"].parent, overwrite=True)
186189

187190
probegroup = recording.get_probegroup()
188191
write_prb(paths["probe_path"].as_posix(), probegroup)
@@ -214,7 +217,7 @@ def test_default_settings_all_represented(self):
214217
tested_keys += additional_non_tested_keys
215218

216219
for param_key in DEFAULT_SETTINGS:
217-
if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]:
220+
if param_key not in ["n_chan_bin", "fs", "tmin", "tmax", "templates_from_data"]:
218221
assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested."
219222

220223
def test_spikeinterface_defaults_against_kilsort(self):
@@ -234,8 +237,11 @@ def test_spikeinterface_defaults_against_kilsort(self):
234237

235238
# Testing Arguments ###
236239
def test_set_files_arguments(self):
240+
expected_arguments = ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"]
241+
if parse(kilosort.__version__) >= parse("4.0.34"):
242+
expected_arguments += ["shank_idx"]
237243
self._check_arguments(
238-
set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"]
244+
set_files, expected_arguments
239245
)
240246

241247
def test_initialize_ops_arguments(self):
@@ -533,33 +539,60 @@ def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, pa
533539
kilosort_output_dir = tmp_path / "kilosort_output_dir"
534540
spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir"
535541

536-
def monkeypatch_filter_function(self, X, ops=None, ibatch=None):
537-
"""
538-
This is a direct copy of the kilosort io.BinaryFiltered.filter
539-
function, with hp_filter and whitening matrix code sections, and
540-
comments removed. This is the easiest way to monkeypatch (tried a few approaches)
541-
"""
542-
if self.chan_map is not None:
543-
X = X[self.chan_map]
544-
545-
if self.invert_sign:
546-
X = X * -1
547-
548-
X = X - X.mean(1).unsqueeze(1)
549-
if self.do_CAR:
550-
X = X - torch.median(X, 0)[0]
551-
552-
if self.hp_filter is not None:
553-
pass
554-
555-
if self.artifact_threshold < np.inf:
556-
if torch.any(torch.abs(X) >= self.artifact_threshold):
557-
return torch.zeros_like(X)
558-
559-
if self.whiten_mat is not None:
560-
pass
561-
return X
562-
542+
if parse(kilosort.__version__) >= parse("4.0.33"):
543+
def monkeypatch_filter_function(self, X, ops=None, ibatch=None, skip_preproc=False):
544+
"""
545+
This is a direct copy of the kilosort io.BinaryFiltered.filter
546+
function, with hp_filter and whitening matrix code sections, and
547+
comments removed. This is the easiest way to monkeypatch (tried a few approaches)
548+
"""
549+
if self.chan_map is not None:
550+
X = X[self.chan_map]
551+
552+
if self.invert_sign:
553+
X = X * -1
554+
555+
X = X - X.mean(1).unsqueeze(1)
556+
if self.do_CAR:
557+
X = X - torch.median(X, 0)[0]
558+
559+
if self.hp_filter is not None:
560+
pass
561+
562+
if self.artifact_threshold < np.inf:
563+
if torch.any(torch.abs(X) >= self.artifact_threshold):
564+
return torch.zeros_like(X)
565+
566+
if self.whiten_mat is not None:
567+
pass
568+
return X
569+
else:
570+
def monkeypatch_filter_function(self, X, ops=None, ibatch=None):
571+
"""
572+
This is a direct copy of the kilosort io.BinaryFiltered.filter
573+
function, with hp_filter and whitening matrix code sections, and
574+
comments removed. This is the easiest way to monkeypatch (tried a few approaches)
575+
"""
576+
if self.chan_map is not None:
577+
X = X[self.chan_map]
578+
579+
if self.invert_sign:
580+
X = X * -1
581+
582+
X = X - X.mean(1).unsqueeze(1)
583+
if self.do_CAR:
584+
X = X - torch.median(X, 0)[0]
585+
586+
if self.hp_filter is not None:
587+
pass
588+
589+
if self.artifact_threshold < np.inf:
590+
if torch.any(torch.abs(X) >= self.artifact_threshold):
591+
return torch.zeros_like(X)
592+
593+
if self.whiten_mat is not None:
594+
pass
595+
return X
563596
monkeypatch.setattr("kilosort.io.BinaryFiltered.filter", monkeypatch_filter_function)
564597

565598
ks_settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value)
@@ -620,7 +653,7 @@ def _get_kilosort_native_settings(self, recording, paths, param_key, param_value
620653
are through the function, these are split here.
621654
"""
622655
settings = {
623-
"data_dir": paths["recording_path"],
656+
"filename": paths["recording_path"],
624657
"n_chan_bin": recording.get_num_channels(),
625658
"fs": recording.get_sampling_frequency(),
626659
}

src/spikeinterface/sorters/external/kilosort4.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
235235
settings_ks["n_chan_bin"] = recording.get_num_channels()
236236
settings_ks["fs"] = recording.sampling_frequency
237237
if not do_CAR:
238-
print("Skipping common average reference.")
238+
if verbose:
239+
print("Skipping common average reference.")
239240

240241
tic0 = time.time()
241242

@@ -252,7 +253,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
252253
bad_channels = params["bad_channels"]
253254
clear_cache = params["clear_cache"]
254255

255-
filename, data_dir, results_dir, probe = set_files(
256+
set_files_kwargs = dict(
256257
settings=settings,
257258
filename=filename,
258259
probe=probe,
@@ -261,6 +262,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
261262
results_dir=results_dir,
262263
bad_channels=bad_channels,
263264
)
265+
if version.parse(ks_version) >= version.parse("4.0.34"):
266+
set_files_kwargs.update(dict(shank_idx=None))
267+
268+
filename, data_dir, results_dir, probe = set_files(**set_files_kwargs)
264269

265270
ops = initialize_ops(
266271
settings=settings,
@@ -271,6 +276,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
271276
device=device,
272277
save_preprocessed_copy=save_preprocessed_copy,
273278
)
279+
if version.parse(ks_version) >= version.parse("4.0.34"):
280+
ops = ops[0]
274281

275282
n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = (
276283
get_run_parameters(ops)
@@ -280,7 +287,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
280287
if not params["skip_kilosort_preprocessing"]:
281288
ops = compute_preprocessing(ops=ops, device=device, tic0=tic0, file_object=file_object)
282289
else:
283-
print("Skipping kilosort preprocessing.")
290+
if verbose:
291+
print("Skipping kilosort preprocessing.")
284292
bfile = BinaryFiltered(
285293
filename=ops["filename"],
286294
n_chan_bin=n_chan_bin,
@@ -309,7 +317,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
309317
torch.random.manual_seed(1)
310318

311319
if not params["do_correction"]:
312-
print("Skipping drift correction.")
320+
if verbose:
321+
print("Skipping drift correction.")
313322
ops["nblocks"] = 0
314323

315324
drift_kwargs = dict(

0 commit comments

Comments
 (0)