2020
2121import pytest
2222import copy
23- from typing import Any
23+ from packaging . version import parse
2424from inspect import signature
2525
2626import numpy as np
7070 "nearest_chans" : 8 ,
7171 "nearest_templates" : 35 ,
7272 "max_channel_distance" : 5 ,
73- "templates_from_data" : False ,
7473 "n_templates" : 10 ,
7574 "n_pcs" : 3 ,
7675 "Th_single_ch" : 4 ,
8483 "duplicate_spike_ms" : 0.3 ,
8584}
8685
87- PARAMS_TO_TEST = list (PARAMS_TO_TEST_DICT .keys ())
88-
8986PARAMETERS_NOT_AFFECTING_RESULTS = [
9087 "artifact_threshold" ,
9188 "ccg_threshold" ,
9592 "duplicate_spike_ms" , # this is because ground-truth spikes don't have violations
9693]
9794
98- # THIS IS A PLACEHOLDER FOR FUTURE PARAMS TO TEST
99- # if parse(version("kilosort")) >= parse("4.0.X"):
100- # PARAMS_TO_TEST_DICT.update(
101- # [
102- # {"new_param": new_value},
103- # ]
104- # )
95+
96+ # Add/Remove version specific parameters
97+ if parse (kilosort .__version__ ) >= parse ("4.0.22" ):
98+ PARAMS_TO_TEST_DICT .update (
99+ {"position_limit" : 50 }
100+ )
101+ # Position limit only affects computing spike locations after sorting
102+ PARAMETERS_NOT_AFFECTING_RESULTS .append ("position_limit" )
103+
104+ if parse (kilosort .__version__ ) >= parse ("4.0.24" ):
105+ PARAMS_TO_TEST_DICT .update (
106+ {"max_peels" : 200 },
107+ )
108+ # max_peels is not affecting the results in this short dataset
109+ PARAMETERS_NOT_AFFECTING_RESULTS .append ("max_peels" )
110+
111+ if parse (kilosort .__version__ ) >= parse ("4.0.33" ):
112+ PARAMS_TO_TEST_DICT .update ({"cluster_neighbors" : 11 })
113+ PARAMETERS_NOT_AFFECTING_RESULTS .append ("cluster_neighbors" )
114+
115+ if parse (kilosort .__version__ ) >= parse ("4.0.37" ):
116+ PARAMS_TO_TEST_DICT .update ({"max_cluster_subset" : 20 })
117+ PARAMETERS_NOT_AFFECTING_RESULTS .append ("max_cluster_subset" )
118+
119+
120+ PARAMS_TO_TEST = list (PARAMS_TO_TEST_DICT .keys ())
105121
106122
107123class TestKilosort4Long :
@@ -169,11 +185,11 @@ def _save_ground_truth_recording(self, recording, tmp_path):
169185 """
170186 paths = {
171187 "session_scope_tmp_path" : tmp_path ,
172- "recording_path" : tmp_path / "my_test_recording" ,
188+ "recording_path" : tmp_path / "my_test_recording" / "traces_cached_seg0.raw" ,
173189 "probe_path" : tmp_path / "my_test_probe.prb" ,
174190 }
175191
176- recording .save (folder = paths ["recording_path" ], overwrite = True )
192+ recording .save (folder = paths ["recording_path" ]. parent , overwrite = True )
177193
178194 probegroup = recording .get_probegroup ()
179195 write_prb (paths ["probe_path" ].as_posix (), probegroup )
@@ -205,7 +221,7 @@ def test_default_settings_all_represented(self):
205221 tested_keys += additional_non_tested_keys
206222
207223 for param_key in DEFAULT_SETTINGS :
208- if param_key not in ["n_chan_bin" , "fs" , "tmin" , "tmax" ]:
224+ if param_key not in ["n_chan_bin" , "fs" , "tmin" , "tmax" , "templates_from_data" ]:
209225 assert param_key in tested_keys , f"param: { param_key } in DEFAULT SETTINGS but not tested."
210226
211227 def test_spikeinterface_defaults_against_kilsort (self ):
@@ -225,8 +241,11 @@ def test_spikeinterface_defaults_against_kilsort(self):
225241
226242 # Testing Arguments ###
227243 def test_set_files_arguments (self ):
244+ expected_arguments = ["settings" , "filename" , "probe" , "probe_name" , "data_dir" , "results_dir" , "bad_channels" ]
245+ if parse (kilosort .__version__ ) >= parse ("4.0.34" ):
246+ expected_arguments += ["shank_idx" ]
228247 self ._check_arguments (
229- set_files , [ "settings" , "filename" , "probe" , "probe_name" , "data_dir" , "results_dir" , "bad_channels" ]
248+ set_files , expected_arguments
230249 )
231250
232251 def test_initialize_ops_arguments (self ):
@@ -239,6 +258,8 @@ def test_initialize_ops_arguments(self):
239258 "device" ,
240259 "save_preprocessed_copy" ,
241260 ]
261+ if parse (kilosort .__version__ ) >= parse ("4.0.37" ):
262+ expected_arguments += ["gui_mode" ]
242263
243264 self ._check_arguments (
244265 initialize_ops ,
@@ -249,17 +270,22 @@ def test_compute_preprocessing_arguments(self):
249270 self ._check_arguments (compute_preprocessing , ["ops" , "device" , "tic0" , "file_object" ])
250271
251272 def test_compute_drift_location_arguments (self ):
252- self ._check_arguments (
253- compute_drift_correction , ["ops" , "device" , "tic0" , "progress_bar" , "file_object" , "clear_cache" ]
254- )
273+ expected_arguments = ["ops" , "device" , "tic0" , "progress_bar" , "file_object" , "clear_cache" ]
274+ if parse (kilosort .__version__ ) >= parse ("4.0.28" ):
275+ expected_arguments += ["verbose" ]
276+ self ._check_arguments (compute_drift_correction , expected_arguments )
255277
256278 def test_detect_spikes_arguments (self ):
257- self ._check_arguments (detect_spikes , ["ops" , "device" , "bfile" , "tic0" , "progress_bar" , "clear_cache" ])
279+ expected_arguments = ["ops" , "device" , "bfile" , "tic0" , "progress_bar" , "clear_cache" ]
280+ if parse (kilosort .__version__ ) >= parse ("4.0.28" ):
281+ expected_arguments += ["verbose" ]
282+ self ._check_arguments (detect_spikes , expected_arguments )
258283
259284 def test_cluster_spikes_arguments (self ):
260- self ._check_arguments (
261- cluster_spikes , ["st" , "tF" , "ops" , "device" , "bfile" , "tic0" , "progress_bar" , "clear_cache" ]
262- )
285+ expected_arguments = ["st" , "tF" , "ops" , "device" , "bfile" , "tic0" , "progress_bar" , "clear_cache" ]
286+ if parse (kilosort .__version__ ) >= parse ("4.0.28" ):
287+ expected_arguments += ["verbose" ]
288+ self ._check_arguments (cluster_spikes , expected_arguments )
263289
264290 def test_save_sorting_arguments (self ):
265291 expected_arguments = ["ops" , "results_dir" , "st" , "clu" , "tF" , "Wall" , "imin" , "tic0" , "save_extra_vars" ]
@@ -519,33 +545,60 @@ def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, pa
519545 kilosort_output_dir = tmp_path / "kilosort_output_dir"
520546 spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir"
521547
522- def monkeypatch_filter_function (self , X , ops = None , ibatch = None ):
523- """
524- This is a direct copy of the kilosort io.BinaryFiltered.filter
525- function, with hp_filter and whitening matrix code sections, and
526- comments removed. This is the easiest way to monkeypatch (tried a few approaches)
527- """
528- if self .chan_map is not None :
529- X = X [self .chan_map ]
530-
531- if self .invert_sign :
532- X = X * - 1
533-
534- X = X - X .mean (1 ).unsqueeze (1 )
535- if self .do_CAR :
536- X = X - torch .median (X , 0 )[0 ]
537-
538- if self .hp_filter is not None :
539- pass
540-
541- if self .artifact_threshold < np .inf :
542- if torch .any (torch .abs (X ) >= self .artifact_threshold ):
543- return torch .zeros_like (X )
544-
545- if self .whiten_mat is not None :
546- pass
547- return X
548-
548+ if parse (kilosort .__version__ ) >= parse ("4.0.33" ):
549+ def monkeypatch_filter_function (self , X , ops = None , ibatch = None , skip_preproc = False ):
550+ """
551+ This is a direct copy of the kilosort io.BinaryFiltered.filter
552+ function, with hp_filter and whitening matrix code sections, and
553+ comments removed. This is the easiest way to monkeypatch (tried a few approaches)
554+ """
555+ if self .chan_map is not None :
556+ X = X [self .chan_map ]
557+
558+ if self .invert_sign :
559+ X = X * - 1
560+
561+ X = X - X .mean (1 ).unsqueeze (1 )
562+ if self .do_CAR :
563+ X = X - torch .median (X , 0 )[0 ]
564+
565+ if self .hp_filter is not None :
566+ pass
567+
568+ if self .artifact_threshold < np .inf :
569+ if torch .any (torch .abs (X ) >= self .artifact_threshold ):
570+ return torch .zeros_like (X )
571+
572+ if self .whiten_mat is not None :
573+ pass
574+ return X
575+ else :
576+ def monkeypatch_filter_function (self , X , ops = None , ibatch = None ):
577+ """
578+ This is a direct copy of the kilosort io.BinaryFiltered.filter
579+ function, with hp_filter and whitening matrix code sections, and
580+ comments removed. This is the easiest way to monkeypatch (tried a few approaches)
581+ """
582+ if self .chan_map is not None :
583+ X = X [self .chan_map ]
584+
585+ if self .invert_sign :
586+ X = X * - 1
587+
588+ X = X - X .mean (1 ).unsqueeze (1 )
589+ if self .do_CAR :
590+ X = X - torch .median (X , 0 )[0 ]
591+
592+ if self .hp_filter is not None :
593+ pass
594+
595+ if self .artifact_threshold < np .inf :
596+ if torch .any (torch .abs (X ) >= self .artifact_threshold ):
597+ return torch .zeros_like (X )
598+
599+ if self .whiten_mat is not None :
600+ pass
601+ return X
549602 monkeypatch .setattr ("kilosort.io.BinaryFiltered.filter" , monkeypatch_filter_function )
550603
551604 ks_settings , _ , ks_format_probe = self ._get_kilosort_native_settings (recording , paths , param_key , param_value )
@@ -606,7 +659,7 @@ def _get_kilosort_native_settings(self, recording, paths, param_key, param_value
606659 are through the function, these are split here.
607660 """
608661 settings = {
609- "data_dir " : paths ["recording_path" ],
662+ "filename " : paths ["recording_path" ],
610663 "n_chan_bin" : recording .get_num_channels (),
611664 "fs" : recording .get_sampling_frequency (),
612665 }
0 commit comments