Skip to content

Commit 6cf41bb

Browse files
authored
Make get_data for valid unit periods use unit ids, not indices (#4468)
1 parent a55fb0e commit 6cf41bb

2 files changed

Lines changed: 16 additions & 7 deletions

File tree

src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def test_user_defined_periods(self):
3434
periods[idx]["unit_index"] = unit_index
3535
period_start = num_samples // 4
3636
period_duration = num_samples // 2
37-
periods[idx]["start_sample_index"] = period_start
38-
periods[idx]["end_sample_index"] = period_start + period_duration
37+
periods[idx]["start_sample_index"] = period_start - unit_index * 10
38+
periods[idx]["end_sample_index"] = period_start + period_duration + unit_index * 10
3939
periods[idx]["segment_index"] = segment_index
4040

4141
sorting_analyzer = self._prepare_sorting_analyzer(
@@ -48,8 +48,17 @@ def test_user_defined_periods(self):
4848
minimum_valid_period_duration=1,
4949
)
5050
# check that valid periods correspond to user defined periods
51-
ext_periods = ext.get_data(outputs="numpy")
52-
np.testing.assert_array_equal(ext_periods, periods)
51+
ext_periods_numpy = ext.get_data(outputs="numpy")
52+
np.testing.assert_array_equal(ext_periods_numpy, periods)
53+
54+
# check that `numpy` and `by_unit` outputs are the same
55+
ext_periods_by_unit = ext.get_data(outputs="by_unit")
56+
for segment_index in range(num_segments):
57+
for unit_index, unit_id in enumerate(unit_ids):
58+
periods_numpy_seg0 = ext_periods_numpy[ext_periods_numpy["segment_index"] == segment_index]
59+
periods_numpy_unit = periods_numpy_seg0[periods_numpy_seg0["unit_index"] == unit_index]
60+
period = [(periods_numpy_unit["start_sample_index"][0], periods_numpy_unit["end_sample_index"][0])]
61+
assert period == ext_periods_by_unit[segment_index][unit_id]
5362

5463
def test_user_defined_periods_as_arrays(self):
5564
unit_ids = self.sorting.unit_ids

src/spikeinterface/postprocessing/valid_unit_periods.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -548,12 +548,12 @@ def _get_data(self, outputs: str = "by_unit"):
548548
for segment_index in range(self.sorting_analyzer.get_num_segments()):
549549
segment_mask = good_periods_array["segment_index"] == segment_index
550550
periods_dict = {}
551-
for unit_index in unit_ids:
552-
periods_dict[unit_index] = []
551+
for unit_index, unit_id in enumerate(unit_ids):
552+
periods_dict[unit_id] = []
553553
unit_mask = good_periods_array["unit_index"] == unit_index
554554
good_periods_unit_segment = good_periods_array[segment_mask & unit_mask]
555555
for start, end in good_periods_unit_segment[["start_sample_index", "end_sample_index"]]:
556-
periods_dict[unit_index].append((start, end))
556+
periods_dict[unit_id].append((start, end))
557557
good_periods.append(periods_dict)
558558

559559
return good_periods

0 commit comments

Comments
 (0)