Skip to content

Commit 8d73efd

Browse files
alejoe91hayleyboundschrishalcrow
authored
0.104.0 bug fixes PRs (#4479)
Co-authored-by: hayleybounds <hayleybounds@gmail.com> Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com>
1 parent 98d36ec commit 8d73efd

File tree

5 files changed

+41
-20
lines changed

5 files changed

+41
-20
lines changed

src/spikeinterface/extractors/cbin_ibl.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,7 @@ def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap",
106106
else:
107107
self.set_probe(probe, in_place=True)
108108

109-
# load num_channels_per_adc depending on probe type
110-
ptype = probe.annotations["probe_type"]
111-
112-
if ptype in [21, 24]: # NP2.0
113-
num_channels_per_adc = 16
114-
else: # NP1.0
115-
num_channels_per_adc = 12
116-
sample_shifts = get_neuropixels_sample_shifts_from_probe(probe, num_channels_per_adc)
109+
sample_shifts = get_neuropixels_sample_shifts_from_probe(probe)
117110
self.set_property("inter_sample_shift", sample_shifts)
118111

119112
self._kwargs = {

src/spikeinterface/metrics/quality/quality_metrics.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,20 @@ def _handle_backward_compatibility_on_load(self):
7777
if "peak_sign" in self.params["metric_params"]["amplitude_median"]:
7878
del self.params["metric_params"]["amplitude_median"]["peak_sign"]
7979

80+
# TODO: update this once `main_channel_index` PR is merged
81+
# global peak_sign used to find appropriate channels for pca metric computation
82+
# If not found, use a "peak_sign" set by any metric
83+
global_peak_sign_from_params = self.params.get("peak_sign")
84+
if global_peak_sign_from_params is None:
85+
for metric_params in self.params["metric_params"].values():
86+
if "peak_sign" in metric_params:
87+
global_peak_sign_from_params = metric_params["peak_sign"]
88+
break
89+
# If still not found, use <0.104.0 default, "neg"
90+
if global_peak_sign_from_params is None:
91+
global_peak_sign_from_params = "neg"
92+
self.params["peak_sign"] = global_peak_sign_from_params
93+
8094
def _set_params(
8195
self,
8296
metric_names: list[str] | None = None,

src/spikeinterface/metrics/template/template_metrics.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,11 @@ def _handle_backward_compatibility_on_load(self):
169169
if "waveform_ratios" not in self.params["metric_names"]:
170170
self.params["metric_names"].append("waveform_ratios")
171171

172+
# If original analyzer doesn't have "peaks_data" or "main_channel_templates",
173+
# then we can't save this tmp data (important for merges/splits)
174+
if "peaks_data" not in self.data:
175+
self.tmp_data_to_save = []
176+
172177
def _set_params(
173178
self,
174179
metric_names: list[str] | None = None,
@@ -234,12 +239,12 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
234239
unit_ids = sorting_analyzer.unit_ids
235240
peak_sign = self.params["peak_sign"]
236241
upsampling_factor = self.params["upsampling_factor"]
237-
min_thresh_detect_peaks_troughs = self.params["min_thresh_detect_peaks_troughs"]
238-
edge_exclusion_ms = self.params.get("edge_exclusion_ms", 0.1)
242+
min_thresh_detect_peaks_troughs = self.params.get("min_thresh_detect_peaks_troughs", 0.3)
243+
edge_exclusion_ms = self.params.get("edge_exclusion_ms", 0.09)
239244
min_peak_trough_distance_ratio = self.params.get("min_peak_trough_distance_ratio", 0.2)
240245
min_extremum_distance_samples = self.params.get("min_extremum_distance_samples", 3)
241246
sampling_frequency = sorting_analyzer.sampling_frequency
242-
if self.params["upsampling_factor"] > 1:
247+
if upsampling_factor > 1:
243248
sampling_frequency_up = upsampling_factor * sampling_frequency
244249
else:
245250
sampling_frequency_up = sampling_frequency
@@ -249,7 +254,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
249254
m in get_multi_channel_template_metric_names() for m in self.params["metrics_to_compute"]
250255
)
251256

252-
operator = self.params["template_operator"]
257+
operator = self.params.get("template_operator", "average")
253258
extremum_channel_indices = get_template_extremum_channel(
254259
sorting_analyzer, peak_sign=peak_sign, outputs="index", operator=operator
255260
)
@@ -312,7 +317,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
312317
# multi_channel_templates is a list of 2D arrays of shape (n_times, n_channels)
313318
tmp_data["multi_channel_templates"] = multi_channel_templates
314319
tmp_data["channel_locations_multi"] = channel_locations_multi
315-
tmp_data["depth_direction"] = self.params["depth_direction"]
320+
tmp_data["depth_direction"] = self.params.get("depth_direction", "y")
316321

317322
# Add peaks_info and preprocessed templates to self.data for storage in extension
318323
columns = []

src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def test_user_defined_periods(self):
3434
periods[idx]["unit_index"] = unit_index
3535
period_start = num_samples // 4
3636
period_duration = num_samples // 2
37-
periods[idx]["start_sample_index"] = period_start
38-
periods[idx]["end_sample_index"] = period_start + period_duration
37+
periods[idx]["start_sample_index"] = period_start - unit_index * 10
38+
periods[idx]["end_sample_index"] = period_start + period_duration + unit_index * 10
3939
periods[idx]["segment_index"] = segment_index
4040

4141
sorting_analyzer = self._prepare_sorting_analyzer(
@@ -48,8 +48,17 @@ def test_user_defined_periods(self):
4848
minimum_valid_period_duration=1,
4949
)
5050
# check that valid periods correspond to user defined periods
51-
ext_periods = ext.get_data(outputs="numpy")
52-
np.testing.assert_array_equal(ext_periods, periods)
51+
ext_periods_numpy = ext.get_data(outputs="numpy")
52+
np.testing.assert_array_equal(ext_periods_numpy, periods)
53+
54+
# check that `numpy` and `by_unit` outputs are the same
55+
ext_periods_by_unit = ext.get_data(outputs="by_unit")
56+
for segment_index in range(num_segments):
57+
for unit_index, unit_id in enumerate(unit_ids):
58+
periods_numpy_seg0 = ext_periods_numpy[ext_periods_numpy["segment_index"] == segment_index]
59+
periods_numpy_unit = periods_numpy_seg0[periods_numpy_seg0["unit_index"] == unit_index]
60+
period = [(periods_numpy_unit["start_sample_index"][0], periods_numpy_unit["end_sample_index"][0])]
61+
assert period == ext_periods_by_unit[segment_index][unit_id]
5362

5463
def test_user_defined_periods_as_arrays(self):
5564
unit_ids = self.sorting.unit_ids

src/spikeinterface/postprocessing/valid_unit_periods.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -548,12 +548,12 @@ def _get_data(self, outputs: str = "by_unit"):
548548
for segment_index in range(self.sorting_analyzer.get_num_segments()):
549549
segment_mask = good_periods_array["segment_index"] == segment_index
550550
periods_dict = {}
551-
for unit_index in unit_ids:
552-
periods_dict[unit_index] = []
551+
for unit_index, unit_id in enumerate(unit_ids):
552+
periods_dict[unit_id] = []
553553
unit_mask = good_periods_array["unit_index"] == unit_index
554554
good_periods_unit_segment = good_periods_array[segment_mask & unit_mask]
555555
for start, end in good_periods_unit_segment[["start_sample_index", "end_sample_index"]]:
556-
periods_dict[unit_index].append((start, end))
556+
periods_dict[unit_id].append((start, end))
557557
good_periods.append(periods_dict)
558558

559559
return good_periods

0 commit comments

Comments
 (0)