@@ -169,6 +169,11 @@ def _handle_backward_compatibility_on_load(self):
169169 if "waveform_ratios" not in self .params ["metric_names" ]:
170170 self .params ["metric_names" ].append ("waveform_ratios" )
171171
172+ # If original analyzer doesn't have "peaks_data" or "main_channel_templates",
173+ # then we can't save this tmp data (important for merges/splits)
174+ if "peaks_data" not in self .data :
175+ self .tmp_data_to_save = []
176+
172177 def _set_params (
173178 self ,
174179 metric_names : list [str ] | None = None ,
@@ -234,12 +239,12 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
234239 unit_ids = sorting_analyzer .unit_ids
235240 peak_sign = self .params ["peak_sign" ]
236241 upsampling_factor = self .params ["upsampling_factor" ]
237- min_thresh_detect_peaks_troughs = self .params [ "min_thresh_detect_peaks_troughs" ]
238- edge_exclusion_ms = self .params .get ("edge_exclusion_ms" , 0.1 )
242+ min_thresh_detect_peaks_troughs = self .params . get ( "min_thresh_detect_peaks_troughs" , 0.3 )
243+ edge_exclusion_ms = self .params .get ("edge_exclusion_ms" , 0.09 )
239244 min_peak_trough_distance_ratio = self .params .get ("min_peak_trough_distance_ratio" , 0.2 )
240245 min_extremum_distance_samples = self .params .get ("min_extremum_distance_samples" , 3 )
241246 sampling_frequency = sorting_analyzer .sampling_frequency
242- if self . params [ " upsampling_factor" ] > 1 :
247+ if upsampling_factor > 1 :
243248 sampling_frequency_up = upsampling_factor * sampling_frequency
244249 else :
245250 sampling_frequency_up = sampling_frequency
@@ -249,7 +254,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
249254 m in get_multi_channel_template_metric_names () for m in self .params ["metrics_to_compute" ]
250255 )
251256
252- operator = self .params [ "template_operator" ]
257+ operator = self .params . get ( "template_operator" , "average" )
253258 extremum_channel_indices = get_template_extremum_channel (
254259 sorting_analyzer , peak_sign = peak_sign , outputs = "index" , operator = operator
255260 )
@@ -312,7 +317,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
312317 # multi_channel_templates is a list of 2D arrays of shape (n_times, n_channels)
313318 tmp_data ["multi_channel_templates" ] = multi_channel_templates
314319 tmp_data ["channel_locations_multi" ] = channel_locations_multi
315- tmp_data ["depth_direction" ] = self .params [ "depth_direction" ]
320+ tmp_data ["depth_direction" ] = self .params . get ( "depth_direction" , "y" )
316321
317322 # Add peaks_info and preprocessed templates to self.data for storage in extension
318323 columns = []
0 commit comments