Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 66 additions & 33 deletions .github/scripts/test_kilosort4_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
"nearest_chans": 8,
"nearest_templates": 35,
"max_channel_distance": 5,
"templates_from_data": False,
"n_templates": 10,
"n_pcs": 3,
"Th_single_ch": 4,
Expand Down Expand Up @@ -109,6 +108,10 @@
# max_peels is not affecting the results in this short dataset
PARAMETERS_NOT_AFFECTING_RESULTS.append("max_peels")

if parse(kilosort.__version__) >= parse("4.0.33"):
PARAMS_TO_TEST_DICT.update({"cluster_neighbors": 11})
PARAMETERS_NOT_AFFECTING_RESULTS.append("cluster_neighbors")


PARAMS_TO_TEST = list(PARAMS_TO_TEST_DICT.keys())

Expand Down Expand Up @@ -178,11 +181,11 @@ def _save_ground_truth_recording(self, recording, tmp_path):
"""
paths = {
"session_scope_tmp_path": tmp_path,
"recording_path": tmp_path / "my_test_recording",
"recording_path": tmp_path / "my_test_recording" / "traces_cached_seg0.raw",
"probe_path": tmp_path / "my_test_probe.prb",
}

recording.save(folder=paths["recording_path"], overwrite=True)
recording.save(folder=paths["recording_path"].parent, overwrite=True)

probegroup = recording.get_probegroup()
write_prb(paths["probe_path"].as_posix(), probegroup)
Expand Down Expand Up @@ -214,7 +217,7 @@ def test_default_settings_all_represented(self):
tested_keys += additional_non_tested_keys

for param_key in DEFAULT_SETTINGS:
if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]:
if param_key not in ["n_chan_bin", "fs", "tmin", "tmax", "templates_from_data"]:
assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested."

def test_spikeinterface_defaults_against_kilsort(self):
Expand All @@ -234,8 +237,11 @@ def test_spikeinterface_defaults_against_kilsort(self):

# Testing Arguments ###
def test_set_files_arguments(self):
expected_arguments = ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"]
if parse(kilosort.__version__) >= parse("4.0.34"):
expected_arguments += ["shank_idx"]
self._check_arguments(
set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"]
set_files, expected_arguments
)

def test_initialize_ops_arguments(self):
Expand Down Expand Up @@ -533,33 +539,60 @@ def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, pa
kilosort_output_dir = tmp_path / "kilosort_output_dir"
spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir"

def monkeypatch_filter_function(self, X, ops=None, ibatch=None):
"""
This is a direct copy of the kilosort io.BinaryFiltered.filter
function, with hp_filter and whitening matrix code sections, and
comments removed. This is the easiest way to monkeypatch (tried a few approaches)
"""
if self.chan_map is not None:
X = X[self.chan_map]

if self.invert_sign:
X = X * -1

X = X - X.mean(1).unsqueeze(1)
if self.do_CAR:
X = X - torch.median(X, 0)[0]

if self.hp_filter is not None:
pass

if self.artifact_threshold < np.inf:
if torch.any(torch.abs(X) >= self.artifact_threshold):
return torch.zeros_like(X)

if self.whiten_mat is not None:
pass
return X

if parse(kilosort.__version__) >= parse("4.0.33"):
def monkeypatch_filter_function(self, X, ops=None, ibatch=None, skip_preproc=False):
"""
This is a direct copy of the kilosort io.BinaryFiltered.filter
function, with hp_filter and whitening matrix code sections, and
comments removed. This is the easiest way to monkeypatch (tried a few approaches)
"""
if self.chan_map is not None:
X = X[self.chan_map]

if self.invert_sign:
X = X * -1

X = X - X.mean(1).unsqueeze(1)
if self.do_CAR:
X = X - torch.median(X, 0)[0]

if self.hp_filter is not None:
pass

if self.artifact_threshold < np.inf:
if torch.any(torch.abs(X) >= self.artifact_threshold):
return torch.zeros_like(X)

if self.whiten_mat is not None:
pass
return X
else:
def monkeypatch_filter_function(self, X, ops=None, ibatch=None):
"""
This is a direct copy of the kilosort io.BinaryFiltered.filter
function, with hp_filter and whitening matrix code sections, and
comments removed. This is the easiest way to monkeypatch (tried a few approaches)
"""
if self.chan_map is not None:
X = X[self.chan_map]

if self.invert_sign:
X = X * -1

X = X - X.mean(1).unsqueeze(1)
if self.do_CAR:
X = X - torch.median(X, 0)[0]

if self.hp_filter is not None:
pass

if self.artifact_threshold < np.inf:
if torch.any(torch.abs(X) >= self.artifact_threshold):
return torch.zeros_like(X)

if self.whiten_mat is not None:
pass
return X
Copy link
Copy Markdown
Member

@zm711 zm711 May 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is brutal..... Sorry. didn't highlight what I wanted it to. I meant the fact you had to duplicate the function because an used argument change.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hopefully this doesn't happen too soon! Anyways, kudos to @JoeZiminski since the tests to check the arguments of different KS functions are very useful to understand what's changed!

monkeypatch.setattr("kilosort.io.BinaryFiltered.filter", monkeypatch_filter_function)

ks_settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value)
Expand Down Expand Up @@ -620,7 +653,7 @@ def _get_kilosort_native_settings(self, recording, paths, param_key, param_value
are through the function, these are split here.
"""
settings = {
"data_dir": paths["recording_path"],
"filename": paths["recording_path"],
"n_chan_bin": recording.get_num_channels(),
"fs": recording.get_sampling_frequency(),
}
Expand Down
17 changes: 13 additions & 4 deletions src/spikeinterface/sorters/external/kilosort4.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
settings_ks["n_chan_bin"] = recording.get_num_channels()
settings_ks["fs"] = recording.sampling_frequency
if not do_CAR:
print("Skipping common average reference.")
if verbose:
print("Skipping common average reference.")

tic0 = time.time()

Expand All @@ -252,7 +253,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
bad_channels = params["bad_channels"]
clear_cache = params["clear_cache"]

filename, data_dir, results_dir, probe = set_files(
set_files_kwargs = dict(
settings=settings,
filename=filename,
probe=probe,
Expand All @@ -261,6 +262,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
results_dir=results_dir,
bad_channels=bad_channels,
)
if version.parse(ks_version) >= version.parse("4.0.34"):
set_files_kwargs.update(dict(shank_idx=None))

filename, data_dir, results_dir, probe = set_files(**set_files_kwargs)

ops = initialize_ops(
settings=settings,
Expand All @@ -271,6 +276,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
device=device,
save_preprocessed_copy=save_preprocessed_copy,
)
if version.parse(ks_version) >= version.parse("4.0.34"):
ops = ops[0]

n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = (
get_run_parameters(ops)
Expand All @@ -280,7 +287,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
if not params["skip_kilosort_preprocessing"]:
ops = compute_preprocessing(ops=ops, device=device, tic0=tic0, file_object=file_object)
else:
print("Skipping kilosort preprocessing.")
if verbose:
print("Skipping kilosort preprocessing.")
bfile = BinaryFiltered(
filename=ops["filename"],
n_chan_bin=n_chan_bin,
Expand Down Expand Up @@ -309,7 +317,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
torch.random.manual_seed(1)

if not params["do_correction"]:
print("Skipping drift correction.")
if verbose:
print("Skipping drift correction.")
ops["nblocks"] = 0

drift_kwargs = dict(
Expand Down