@@ -34,8 +34,8 @@ def test_user_defined_periods(self):
3434 periods [idx ]["unit_index" ] = unit_index
3535 period_start = num_samples // 4
3636 period_duration = num_samples // 2
37- periods [idx ]["start_sample_index" ] = period_start
38- periods [idx ]["end_sample_index" ] = period_start + period_duration
37+ periods [idx ]["start_sample_index" ] = period_start - unit_index * 10
38+ periods [idx ]["end_sample_index" ] = period_start + period_duration + unit_index * 10
3939 periods [idx ]["segment_index" ] = segment_index
4040
4141 sorting_analyzer = self ._prepare_sorting_analyzer (
@@ -48,8 +48,17 @@ def test_user_defined_periods(self):
4848 minimum_valid_period_duration = 1 ,
4949 )
5050 # check that valid periods correspond to user defined periods
51- ext_periods = ext .get_data (outputs = "numpy" )
52- np .testing .assert_array_equal (ext_periods , periods )
51+ ext_periods_numpy = ext .get_data (outputs = "numpy" )
52+ np .testing .assert_array_equal (ext_periods_numpy , periods )
53+
54+ # check that `numpy` and `by_unit` outputs are the same
55+ ext_periods_by_unit = ext .get_data (outputs = "by_unit" )
56+ for segment_index in range (num_segments ):
57+ for unit_index , unit_id in enumerate (unit_ids ):
58+ periods_numpy_seg0 = ext_periods_numpy [ext_periods_numpy ["segment_index" ] == segment_index ]
59+ periods_numpy_unit = periods_numpy_seg0 [periods_numpy_seg0 ["unit_index" ] == unit_index ]
60+ period = [(periods_numpy_unit ["start_sample_index" ][0 ], periods_numpy_unit ["end_sample_index" ][0 ])]
61+ assert period == ext_periods_by_unit [segment_index ][unit_id ]
5362
5463 def test_user_defined_periods_as_arrays (self ):
5564 unit_ids = self .sorting .unit_ids
0 commit comments