7070 "nearest_chans" : 8 ,
7171 "nearest_templates" : 35 ,
7272 "max_channel_distance" : 5 ,
73- "templates_from_data" : False ,
7473 "n_templates" : 10 ,
7574 "n_pcs" : 3 ,
7675 "Th_single_ch" : 4 ,
109108 # max_peels is not affecting the results in this short dataset
110109 PARAMETERS_NOT_AFFECTING_RESULTS .append ("max_peels" )
111110
111+ if parse (kilosort .__version__ ) >= parse ("4.0.33" ):
112+ PARAMS_TO_TEST_DICT .update ({"cluster_neighbors" : 11 })
113+ PARAMETERS_NOT_AFFECTING_RESULTS .append ("cluster_neighbors" )
114+
115+ if parse (kilosort .__version__ ) >= parse ("4.0.37" ):
116+ PARAMS_TO_TEST_DICT .update ({"max_cluster_subset" : 20 })
117+ PARAMETERS_NOT_AFFECTING_RESULTS .append ("max_cluster_subset" )
118+
112119
113120PARAMS_TO_TEST = list (PARAMS_TO_TEST_DICT .keys ())
114121
@@ -178,11 +185,11 @@ def _save_ground_truth_recording(self, recording, tmp_path):
178185 """
179186 paths = {
180187 "session_scope_tmp_path" : tmp_path ,
181- "recording_path" : tmp_path / "my_test_recording" ,
188+ "recording_path" : tmp_path / "my_test_recording" / "traces_cached_seg0.raw" ,
182189 "probe_path" : tmp_path / "my_test_probe.prb" ,
183190 }
184191
185- recording .save (folder = paths ["recording_path" ], overwrite = True )
192+ recording .save (folder = paths ["recording_path" ]. parent , overwrite = True )
186193
187194 probegroup = recording .get_probegroup ()
188195 write_prb (paths ["probe_path" ].as_posix (), probegroup )
@@ -214,7 +221,7 @@ def test_default_settings_all_represented(self):
214221 tested_keys += additional_non_tested_keys
215222
216223 for param_key in DEFAULT_SETTINGS :
217- if param_key not in ["n_chan_bin" , "fs" , "tmin" , "tmax" ]:
224+ if param_key not in ["n_chan_bin" , "fs" , "tmin" , "tmax" , "templates_from_data" ]:
218225 assert param_key in tested_keys , f"param: { param_key } in DEFAULT SETTINGS but not tested."
219226
220227 def test_spikeinterface_defaults_against_kilsort (self ):
@@ -234,8 +241,11 @@ def test_spikeinterface_defaults_against_kilsort(self):
234241
235242 # Testing Arguments ###
236243 def test_set_files_arguments (self ):
244+ expected_arguments = ["settings" , "filename" , "probe" , "probe_name" , "data_dir" , "results_dir" , "bad_channels" ]
245+ if parse (kilosort .__version__ ) >= parse ("4.0.34" ):
246+ expected_arguments += ["shank_idx" ]
237247 self ._check_arguments (
238- set_files , [ "settings" , "filename" , "probe" , "probe_name" , "data_dir" , "results_dir" , "bad_channels" ]
248+ set_files , expected_arguments
239249 )
240250
241251 def test_initialize_ops_arguments (self ):
@@ -248,6 +258,8 @@ def test_initialize_ops_arguments(self):
248258 "device" ,
249259 "save_preprocessed_copy" ,
250260 ]
261+ if parse (kilosort .__version__ ) >= parse ("4.0.37" ):
262+ expected_arguments += ["gui_mode" ]
251263
252264 self ._check_arguments (
253265 initialize_ops ,
@@ -533,33 +545,60 @@ def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, pa
533545 kilosort_output_dir = tmp_path / "kilosort_output_dir"
534546 spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir"
535547
536- def monkeypatch_filter_function (self , X , ops = None , ibatch = None ):
537- """
538- This is a direct copy of the kilosort io.BinaryFiltered.filter
539- function, with hp_filter and whitening matrix code sections, and
540- comments removed. This is the easiest way to monkeypatch (tried a few approaches)
541- """
542- if self .chan_map is not None :
543- X = X [self .chan_map ]
544-
545- if self .invert_sign :
546- X = X * - 1
547-
548- X = X - X .mean (1 ).unsqueeze (1 )
549- if self .do_CAR :
550- X = X - torch .median (X , 0 )[0 ]
551-
552- if self .hp_filter is not None :
553- pass
554-
555- if self .artifact_threshold < np .inf :
556- if torch .any (torch .abs (X ) >= self .artifact_threshold ):
557- return torch .zeros_like (X )
558-
559- if self .whiten_mat is not None :
560- pass
561- return X
562-
548+ if parse (kilosort .__version__ ) >= parse ("4.0.33" ):
549+ def monkeypatch_filter_function (self , X , ops = None , ibatch = None , skip_preproc = False ):
550+ """
551+ This is a direct copy of the kilosort io.BinaryFiltered.filter
552+ function, with hp_filter and whitening matrix code sections, and
553+ comments removed. This is the easiest way to monkeypatch (tried a few approaches)
554+ """
555+ if self .chan_map is not None :
556+ X = X [self .chan_map ]
557+
558+ if self .invert_sign :
559+ X = X * - 1
560+
561+ X = X - X .mean (1 ).unsqueeze (1 )
562+ if self .do_CAR :
563+ X = X - torch .median (X , 0 )[0 ]
564+
565+ if self .hp_filter is not None :
566+ pass
567+
568+ if self .artifact_threshold < np .inf :
569+ if torch .any (torch .abs (X ) >= self .artifact_threshold ):
570+ return torch .zeros_like (X )
571+
572+ if self .whiten_mat is not None :
573+ pass
574+ return X
575+ else :
576+ def monkeypatch_filter_function (self , X , ops = None , ibatch = None ):
577+ """
578+ This is a direct copy of the kilosort io.BinaryFiltered.filter
579+ function, with hp_filter and whitening matrix code sections, and
580+ comments removed. This is the easiest way to monkeypatch (tried a few approaches)
581+ """
582+ if self .chan_map is not None :
583+ X = X [self .chan_map ]
584+
585+ if self .invert_sign :
586+ X = X * - 1
587+
588+ X = X - X .mean (1 ).unsqueeze (1 )
589+ if self .do_CAR :
590+ X = X - torch .median (X , 0 )[0 ]
591+
592+ if self .hp_filter is not None :
593+ pass
594+
595+ if self .artifact_threshold < np .inf :
596+ if torch .any (torch .abs (X ) >= self .artifact_threshold ):
597+ return torch .zeros_like (X )
598+
599+ if self .whiten_mat is not None :
600+ pass
601+ return X
563602 monkeypatch .setattr ("kilosort.io.BinaryFiltered.filter" , monkeypatch_filter_function )
564603
565604 ks_settings , _ , ks_format_probe = self ._get_kilosort_native_settings (recording , paths , param_key , param_value )
@@ -620,7 +659,7 @@ def _get_kilosort_native_settings(self, recording, paths, param_key, param_value
620659 are through the function, these are split here.
621660 """
622661 settings = {
623- "data_dir " : paths ["recording_path" ],
662+ "filename " : paths ["recording_path" ],
624663 "n_chan_bin" : recording .get_num_channels (),
625664 "fs" : recording .get_sampling_frequency (),
626665 }
0 commit comments