7070 "nearest_chans" : 8 ,
7171 "nearest_templates" : 35 ,
7272 "max_channel_distance" : 5 ,
73- "templates_from_data" : False ,
7473 "n_templates" : 10 ,
7574 "n_pcs" : 3 ,
7675 "Th_single_ch" : 4 ,
109108 # max_peels is not affecting the results in this short dataset
110109 PARAMETERS_NOT_AFFECTING_RESULTS .append ("max_peels" )
111110
111+ if parse (kilosort .__version__ ) >= parse ("4.0.33" ):
112+ PARAMS_TO_TEST_DICT .update ({"cluster_neighbors" : 11 })
113+ PARAMETERS_NOT_AFFECTING_RESULTS .append ("cluster_neighbors" )
114+
115+ if parse (kilosort .__version__ ) >= parse ("4.0.37" ):
116+ PARAMS_TO_TEST_DICT .update ({"max_cluster_subset" : 20 })
117+ PARAMETERS_NOT_AFFECTING_RESULTS .append ("max_cluster_subset" )
118+
112119
113120PARAMS_TO_TEST = list (PARAMS_TO_TEST_DICT .keys ())
114121
@@ -178,11 +185,11 @@ def _save_ground_truth_recording(self, recording, tmp_path):
178185 """
179186 paths = {
180187 "session_scope_tmp_path" : tmp_path ,
181- "recording_path" : tmp_path / "my_test_recording" ,
188+ "recording_path" : tmp_path / "my_test_recording" / "traces_cached_seg0.raw" ,
182189 "probe_path" : tmp_path / "my_test_probe.prb" ,
183190 }
184191
185- recording .save (folder = paths ["recording_path" ], overwrite = True )
192+ recording .save (folder = paths ["recording_path" ]. parent , overwrite = True )
186193
187194 probegroup = recording .get_probegroup ()
188195 write_prb (paths ["probe_path" ].as_posix (), probegroup )
@@ -214,7 +221,7 @@ def test_default_settings_all_represented(self):
214221 tested_keys += additional_non_tested_keys
215222
216223 for param_key in DEFAULT_SETTINGS :
217- if param_key not in ["n_chan_bin" , "fs" , "tmin" , "tmax" ]:
224+ if param_key not in ["n_chan_bin" , "fs" , "tmin" , "tmax" , "templates_from_data" ]:
218225 assert param_key in tested_keys , f"param: { param_key } in DEFAULT SETTINGS but not tested."
219226
220227 def test_spikeinterface_defaults_against_kilsort (self ):
@@ -234,8 +241,11 @@ def test_spikeinterface_defaults_against_kilsort(self):
234241
235242 # Testing Arguments ###
236243 def test_set_files_arguments (self ):
244+ expected_arguments = ["settings" , "filename" , "probe" , "probe_name" , "data_dir" , "results_dir" , "bad_channels" ]
245+ if parse (kilosort .__version__ ) >= parse ("4.0.34" ):
246+ expected_arguments += ["shank_idx" ]
237247 self ._check_arguments (
238- set_files , [ "settings" , "filename" , "probe" , "probe_name" , "data_dir" , "results_dir" , "bad_channels" ]
248+ set_files , expected_arguments
239249 )
240250
241251 def test_initialize_ops_arguments (self ):
@@ -248,6 +258,8 @@ def test_initialize_ops_arguments(self):
248258 "device" ,
249259 "save_preprocessed_copy" ,
250260 ]
261+ if parse (kilosort .__version__ ) >= parse ("4.0.37" ):
262+ expected_arguments += ["gui_mode" ]
251263
252264 self ._check_arguments (
253265 initialize_ops ,
@@ -258,22 +270,29 @@ def test_compute_preprocessing_arguments(self):
258270 self ._check_arguments (compute_preprocessing , ["ops" , "device" , "tic0" , "file_object" ])
259271
260272 def test_compute_drift_location_arguments (self ):
261- self ._check_arguments (
262- compute_drift_correction , ["ops" , "device" , "tic0" , "progress_bar" , "file_object" , "clear_cache" ]
263- )
273+ expected_arguments = ["ops" , "device" , "tic0" , "progress_bar" , "file_object" , "clear_cache" ]
274+ if parse (kilosort .__version__ ) >= parse ("4.0.28" ):
275+ expected_arguments += ["verbose" ]
276+ self ._check_arguments (compute_drift_correction , expected_arguments )
264277
265278 def test_detect_spikes_arguments (self ):
266- self ._check_arguments (detect_spikes , ["ops" , "device" , "bfile" , "tic0" , "progress_bar" , "clear_cache" ])
279+ expected_arguments = ["ops" , "device" , "bfile" , "tic0" , "progress_bar" , "clear_cache" ]
280+ if parse (kilosort .__version__ ) >= parse ("4.0.28" ):
281+ expected_arguments += ["verbose" ]
282+ self ._check_arguments (detect_spikes , expected_arguments )
267283
268284 def test_cluster_spikes_arguments (self ):
269- self ._check_arguments (
270- cluster_spikes , ["st" , "tF" , "ops" , "device" , "bfile" , "tic0" , "progress_bar" , "clear_cache" ]
271- )
285+ expected_arguments = ["st" , "tF" , "ops" , "device" , "bfile" , "tic0" , "progress_bar" , "clear_cache" ]
286+ if parse (kilosort .__version__ ) >= parse ("4.0.28" ):
287+ expected_arguments += ["verbose" ]
288+ self ._check_arguments (cluster_spikes , expected_arguments )
272289
273290 def test_save_sorting_arguments (self ):
274- expected_arguments = ["ops" , "results_dir" , "st" , "clu" , "tF" , "Wall" , "imin" , "tic0" , "save_extra_vars" ]
275-
276- expected_arguments .append ("save_preprocessed_copy" )
291+ expected_arguments = [
292+ "ops" , "results_dir" , "st" , "clu" , "tF" , "Wall" , "imin" , "tic0" , "save_extra_vars" , "save_preprocessed_copy"
293+ ]
294+ if parse (kilosort .__version__ ) >= parse ("4.0.39" ):
295+ expected_arguments .append ("skip_dat_path" )
277296
278297 self ._check_arguments (save_sorting , expected_arguments )
279298
@@ -528,33 +547,60 @@ def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, pa
528547 kilosort_output_dir = tmp_path / "kilosort_output_dir"
529548 spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir"
530549
531- def monkeypatch_filter_function (self , X , ops = None , ibatch = None ):
532- """
533- This is a direct copy of the kilosort io.BinaryFiltered.filter
534- function, with hp_filter and whitening matrix code sections, and
535- comments removed. This is the easiest way to monkeypatch (tried a few approaches)
536- """
537- if self .chan_map is not None :
538- X = X [self .chan_map ]
539-
540- if self .invert_sign :
541- X = X * - 1
542-
543- X = X - X .mean (1 ).unsqueeze (1 )
544- if self .do_CAR :
545- X = X - torch .median (X , 0 )[0 ]
546-
547- if self .hp_filter is not None :
548- pass
549-
550- if self .artifact_threshold < np .inf :
551- if torch .any (torch .abs (X ) >= self .artifact_threshold ):
552- return torch .zeros_like (X )
553-
554- if self .whiten_mat is not None :
555- pass
556- return X
557-
550+ if parse (kilosort .__version__ ) >= parse ("4.0.33" ):
551+ def monkeypatch_filter_function (self , X , ops = None , ibatch = None , skip_preproc = False ):
552+ """
553+ This is a direct copy of the kilosort io.BinaryFiltered.filter
554+ function, with hp_filter and whitening matrix code sections, and
555+ comments removed. This is the easiest way to monkeypatch (tried a few approaches)
556+ """
557+ if self .chan_map is not None :
558+ X = X [self .chan_map ]
559+
560+ if self .invert_sign :
561+ X = X * - 1
562+
563+ X = X - X .mean (1 ).unsqueeze (1 )
564+ if self .do_CAR :
565+ X = X - torch .median (X , 0 )[0 ]
566+
567+ if self .hp_filter is not None :
568+ pass
569+
570+ if self .artifact_threshold < np .inf :
571+ if torch .any (torch .abs (X ) >= self .artifact_threshold ):
572+ return torch .zeros_like (X )
573+
574+ if self .whiten_mat is not None :
575+ pass
576+ return X
577+ else :
578+ def monkeypatch_filter_function (self , X , ops = None , ibatch = None ):
579+ """
580+ This is a direct copy of the kilosort io.BinaryFiltered.filter
581+ function, with hp_filter and whitening matrix code sections, and
582+ comments removed. This is the easiest way to monkeypatch (tried a few approaches)
583+ """
584+ if self .chan_map is not None :
585+ X = X [self .chan_map ]
586+
587+ if self .invert_sign :
588+ X = X * - 1
589+
590+ X = X - X .mean (1 ).unsqueeze (1 )
591+ if self .do_CAR :
592+ X = X - torch .median (X , 0 )[0 ]
593+
594+ if self .hp_filter is not None :
595+ pass
596+
597+ if self .artifact_threshold < np .inf :
598+ if torch .any (torch .abs (X ) >= self .artifact_threshold ):
599+ return torch .zeros_like (X )
600+
601+ if self .whiten_mat is not None :
602+ pass
603+ return X
558604 monkeypatch .setattr ("kilosort.io.BinaryFiltered.filter" , monkeypatch_filter_function )
559605
560606 ks_settings , _ , ks_format_probe = self ._get_kilosort_native_settings (recording , paths , param_key , param_value )
@@ -615,7 +661,7 @@ def _get_kilosort_native_settings(self, recording, paths, param_key, param_value
615661 are through the function, these are split here.
616662 """
617663 settings = {
618- "data_dir " : paths ["recording_path" ],
664+ "filename " : paths ["recording_path" ],
619665 "n_chan_bin" : recording .get_num_channels (),
620666 "fs" : recording .get_sampling_frequency (),
621667 }
0 commit comments