Skip to content

Commit 8caa69f

Browse files
committed
Fix filter monkeypatch
1 parent 0944fa0 commit 8caa69f

1 file changed

Lines changed: 57 additions & 29 deletions

File tree

.github/scripts/test_kilosort4_ci.py

Lines changed: 57 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,9 @@
108108
# max_peels is not affecting the results in this short dataset
109109
PARAMETERS_NOT_AFFECTING_RESULTS.append("max_peels")
110110

111-
if parse(kilosort.__version__) >= parse("4.0.34"):
112-
PARAMS_TO_TEST_DICT.update({"cluster_neighbors": 20})
111+
if parse(kilosort.__version__) >= parse("4.0.33"):
112+
PARAMS_TO_TEST_DICT.update({"cluster_neighbors": 11})
113+
PARAMETERS_NOT_AFFECTING_RESULTS.append("position_limit")
113114

114115

115116
PARAMS_TO_TEST = list(PARAMS_TO_TEST_DICT.keys())
@@ -538,33 +539,60 @@ def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, pa
538539
kilosort_output_dir = tmp_path / "kilosort_output_dir"
539540
spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir"
540541

541-
def monkeypatch_filter_function(self, X, ops=None, ibatch=None):
542-
"""
543-
This is a direct copy of the kilosort io.BinaryFiltered.filter
544-
function, with hp_filter and whitening matrix code sections, and
545-
comments removed. This is the easiest way to monkeypatch (tried a few approaches)
546-
"""
547-
if self.chan_map is not None:
548-
X = X[self.chan_map]
549-
550-
if self.invert_sign:
551-
X = X * -1
552-
553-
X = X - X.mean(1).unsqueeze(1)
554-
if self.do_CAR:
555-
X = X - torch.median(X, 0)[0]
556-
557-
if self.hp_filter is not None:
558-
pass
559-
560-
if self.artifact_threshold < np.inf:
561-
if torch.any(torch.abs(X) >= self.artifact_threshold):
562-
return torch.zeros_like(X)
563-
564-
if self.whiten_mat is not None:
565-
pass
566-
return X
567-
542+
if parse(kilosort.__version__) >= parse("4.0.33"):
543+
def monkeypatch_filter_function(self, X, ops=None, ibatch=None, skip_preproc=False):
544+
"""
545+
This is a direct copy of the kilosort io.BinaryFiltered.filter
546+
function, with hp_filter and whitening matrix code sections, and
547+
comments removed. This is the easiest way to monkeypatch (tried a few approaches)
548+
"""
549+
if self.chan_map is not None:
550+
X = X[self.chan_map]
551+
552+
if self.invert_sign:
553+
X = X * -1
554+
555+
X = X - X.mean(1).unsqueeze(1)
556+
if self.do_CAR:
557+
X = X - torch.median(X, 0)[0]
558+
559+
if self.hp_filter is not None:
560+
pass
561+
562+
if self.artifact_threshold < np.inf:
563+
if torch.any(torch.abs(X) >= self.artifact_threshold):
564+
return torch.zeros_like(X)
565+
566+
if self.whiten_mat is not None:
567+
pass
568+
return X
569+
else:
570+
def monkeypatch_filter_function(self, X, ops=None, ibatch=None):
571+
"""
572+
This is a direct copy of the kilosort io.BinaryFiltered.filter
573+
function, with hp_filter and whitening matrix code sections, and
574+
comments removed. This is the easiest way to monkeypatch (tried a few approaches)
575+
"""
576+
if self.chan_map is not None:
577+
X = X[self.chan_map]
578+
579+
if self.invert_sign:
580+
X = X * -1
581+
582+
X = X - X.mean(1).unsqueeze(1)
583+
if self.do_CAR:
584+
X = X - torch.median(X, 0)[0]
585+
586+
if self.hp_filter is not None:
587+
pass
588+
589+
if self.artifact_threshold < np.inf:
590+
if torch.any(torch.abs(X) >= self.artifact_threshold):
591+
return torch.zeros_like(X)
592+
593+
if self.whiten_mat is not None:
594+
pass
595+
return X
568596
monkeypatch.setattr("kilosort.io.BinaryFiltered.filter", monkeypatch_filter_function)
569597

570598
ks_settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value)

0 commit comments

Comments
 (0)