Skip to content

Commit 8d4af50

Browse files
authored
Remove tmp_data_to_save when loading old analyzers (#4480)
1 parent 665b0e7 commit 8d4af50

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

src/spikeinterface/metrics/template/template_metrics.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,11 @@ def _handle_backward_compatibility_on_load(self):
169169
if "waveform_ratios" not in self.params["metric_names"]:
170170
self.params["metric_names"].append("waveform_ratios")
171171

172+
# If original analyzer doesn't have "peaks_data" or "main_channel_templates",
173+
# then we can't save this tmp data (important for merges/splits)
174+
if "peaks_data" not in self.data:
175+
self.tmp_data_to_save = []
176+
172177
def _set_params(
173178
self,
174179
metric_names: list[str] | None = None,
@@ -234,12 +239,12 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
234239
unit_ids = sorting_analyzer.unit_ids
235240
peak_sign = self.params["peak_sign"]
236241
upsampling_factor = self.params["upsampling_factor"]
237-
min_thresh_detect_peaks_troughs = self.params["min_thresh_detect_peaks_troughs"]
238-
edge_exclusion_ms = self.params.get("edge_exclusion_ms", 0.1)
242+
min_thresh_detect_peaks_troughs = self.params.get("min_thresh_detect_peaks_troughs", 0.3)
243+
edge_exclusion_ms = self.params.get("edge_exclusion_ms", 0.09)
239244
min_peak_trough_distance_ratio = self.params.get("min_peak_trough_distance_ratio", 0.2)
240245
min_extremum_distance_samples = self.params.get("min_extremum_distance_samples", 3)
241246
sampling_frequency = sorting_analyzer.sampling_frequency
242-
if self.params["upsampling_factor"] > 1:
247+
if upsampling_factor > 1:
243248
sampling_frequency_up = upsampling_factor * sampling_frequency
244249
else:
245250
sampling_frequency_up = sampling_frequency
@@ -249,7 +254,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
249254
m in get_multi_channel_template_metric_names() for m in self.params["metrics_to_compute"]
250255
)
251256

252-
operator = self.params["template_operator"]
257+
operator = self.params.get("template_operator", "average")
253258
extremum_channel_indices = get_template_extremum_channel(
254259
sorting_analyzer, peak_sign=peak_sign, outputs="index", operator=operator
255260
)
@@ -312,7 +317,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
312317
# multi_channel_templates is a list of 2D arrays of shape (n_times, n_channels)
313318
tmp_data["multi_channel_templates"] = multi_channel_templates
314319
tmp_data["channel_locations_multi"] = channel_locations_multi
315-
tmp_data["depth_direction"] = self.params["depth_direction"]
320+
tmp_data["depth_direction"] = self.params.get("depth_direction", "y")
316321

317322
# Add peaks_info and preprocessed templates to self.data for storage in extension
318323
columns = []

0 commit comments

Comments
 (0)