diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 16e3c1ec7d..0d245e3783 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -70,7 +70,6 @@ "nearest_chans": 8, "nearest_templates": 35, "max_channel_distance": 5, - "templates_from_data": False, "n_templates": 10, "n_pcs": 3, "Th_single_ch": 4, @@ -109,6 +108,10 @@ # max_peels is not affecting the results in this short dataset PARAMETERS_NOT_AFFECTING_RESULTS.append("max_peels") +if parse(kilosort.__version__) >= parse("4.0.33"): + PARAMS_TO_TEST_DICT.update({"cluster_neighbors": 11}) + PARAMETERS_NOT_AFFECTING_RESULTS.append("cluster_neighbors") + PARAMS_TO_TEST = list(PARAMS_TO_TEST_DICT.keys()) @@ -178,11 +181,11 @@ def _save_ground_truth_recording(self, recording, tmp_path): """ paths = { "session_scope_tmp_path": tmp_path, - "recording_path": tmp_path / "my_test_recording", + "recording_path": tmp_path / "my_test_recording" / "traces_cached_seg0.raw", "probe_path": tmp_path / "my_test_probe.prb", } - recording.save(folder=paths["recording_path"], overwrite=True) + recording.save(folder=paths["recording_path"].parent, overwrite=True) probegroup = recording.get_probegroup() write_prb(paths["probe_path"].as_posix(), probegroup) @@ -214,7 +217,7 @@ def test_default_settings_all_represented(self): tested_keys += additional_non_tested_keys for param_key in DEFAULT_SETTINGS: - if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: + if param_key not in ["n_chan_bin", "fs", "tmin", "tmax", "templates_from_data"]: assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." def test_spikeinterface_defaults_against_kilsort(self): @@ -234,8 +237,11 @@ def test_spikeinterface_defaults_against_kilsort(self): # Testing Arguments ### def test_set_files_arguments(self): + expected_arguments = ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"] + if parse(kilosort.__version__) >= parse("4.0.34"): + expected_arguments += ["shank_idx"] self._check_arguments( - set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"] + set_files, expected_arguments ) def test_initialize_ops_arguments(self): @@ -533,33 +539,60 @@ def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, pa kilosort_output_dir = tmp_path / "kilosort_output_dir" spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" - def monkeypatch_filter_function(self, X, ops=None, ibatch=None): - """ - This is a direct copy of the kilosort io.BinaryFiltered.filter - function, with hp_filter and whitening matrix code sections, and - comments removed. This is the easiest way to monkeypatch (tried a few approaches) - """ - if self.chan_map is not None: - X = X[self.chan_map] - - if self.invert_sign: - X = X * -1 - - X = X - X.mean(1).unsqueeze(1) - if self.do_CAR: - X = X - torch.median(X, 0)[0] - - if self.hp_filter is not None: - pass - - if self.artifact_threshold < np.inf: - if torch.any(torch.abs(X) >= self.artifact_threshold): - return torch.zeros_like(X) - - if self.whiten_mat is not None: - pass - return X - + if parse(kilosort.__version__) >= parse("4.0.33"): + def monkeypatch_filter_function(self, X, ops=None, ibatch=None, skip_preproc=False): + """ + This is a direct copy of the kilosort io.BinaryFiltered.filter + function, with hp_filter and whitening matrix code sections, and + comments removed. This is the easiest way to monkeypatch (tried a few approaches) + """ + if self.chan_map is not None: + X = X[self.chan_map] + + if self.invert_sign: + X = X * -1 + + X = X - X.mean(1).unsqueeze(1) + if self.do_CAR: + X = X - torch.median(X, 0)[0] + + if self.hp_filter is not None: + pass + + if self.artifact_threshold < np.inf: + if torch.any(torch.abs(X) >= self.artifact_threshold): + return torch.zeros_like(X) + + if self.whiten_mat is not None: + pass + return X + else: + def monkeypatch_filter_function(self, X, ops=None, ibatch=None): + """ + This is a direct copy of the kilosort io.BinaryFiltered.filter + function, with hp_filter and whitening matrix code sections, and + comments removed. This is the easiest way to monkeypatch (tried a few approaches) + """ + if self.chan_map is not None: + X = X[self.chan_map] + + if self.invert_sign: + X = X * -1 + + X = X - X.mean(1).unsqueeze(1) + if self.do_CAR: + X = X - torch.median(X, 0)[0] + + if self.hp_filter is not None: + pass + + if self.artifact_threshold < np.inf: + if torch.any(torch.abs(X) >= self.artifact_threshold): + return torch.zeros_like(X) + + if self.whiten_mat is not None: + pass + return X monkeypatch.setattr("kilosort.io.BinaryFiltered.filter", monkeypatch_filter_function) ks_settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value) @@ -620,7 +653,7 @@ def _get_kilosort_native_settings(self, recording, paths, param_key, param_value are through the function, these are split here. """ settings = { - "data_dir": paths["recording_path"], + "filename": paths["recording_path"], "n_chan_bin": recording.get_num_channels(), "fs": recording.get_sampling_frequency(), } diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index c61f2663a5..68a0be4317 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -235,7 +235,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): settings_ks["n_chan_bin"] = recording.get_num_channels() settings_ks["fs"] = recording.sampling_frequency if not do_CAR: - print("Skipping common average reference.") + if verbose: + print("Skipping common average reference.") tic0 = time.time() @@ -252,7 +253,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): bad_channels = params["bad_channels"] clear_cache = params["clear_cache"] - filename, data_dir, results_dir, probe = set_files( + set_files_kwargs = dict( settings=settings, filename=filename, probe=probe, @@ -261,6 +262,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): results_dir=results_dir, bad_channels=bad_channels, ) + if version.parse(ks_version) >= version.parse("4.0.34"): + set_files_kwargs.update(dict(shank_idx=None)) + + filename, data_dir, results_dir, probe = set_files(**set_files_kwargs) ops = initialize_ops( settings=settings, @@ -271,6 +276,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): device=device, save_preprocessed_copy=save_preprocessed_copy, ) + if version.parse(ks_version) >= version.parse("4.0.34"): + ops = ops[0] n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( get_run_parameters(ops) @@ -280,7 +287,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if not params["skip_kilosort_preprocessing"]: ops = compute_preprocessing(ops=ops, device=device, tic0=tic0, file_object=file_object) else: - print("Skipping kilosort preprocessing.") + if verbose: + print("Skipping kilosort preprocessing.") bfile = BinaryFiltered( filename=ops["filename"], n_chan_bin=n_chan_bin, @@ -309,7 +317,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): torch.random.manual_seed(1) if not params["do_correction"]: - print("Skipping drift correction.") + if verbose: + print("Skipping drift correction.") ops["nblocks"] = 0 drift_kwargs = dict(