@@ -201,10 +201,10 @@ def __init__(
201201 if len (unique_segments ) == 1 :
202202 segment_indices = [int (unique_segments [0 ])]
203203 else :
204- raise ValueError ("segment_index must be specified if there are multiple segments" )
204+ raise ValueError ("segment_indices must be specified if there are multiple segments" )
205205
206206 if not isinstance (segment_indices , list ):
207- raise ValueError ("segment_index must be an int or a list of ints" )
207+ raise ValueError ("segment_indices must be a list of ints" )
208208
209209 # Validate all segment indices exist in the data
210210 for idx in segment_indices :
@@ -275,7 +275,7 @@ def __init__(
275275 plot_data = dict (
276276 spike_train_data = spike_train_data ,
277277 y_axis_data = y_axis_data ,
278- segment_index = segment_indices ,
278+ segment_indices = segment_indices ,
279279 y_lim = depth_lim ,
280280 color_kwargs = color_kwargs ,
281281 scatter_decimate = scatter_decimate ,
@@ -417,7 +417,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
417417 commpon_drift_map_kwargs = dict (
418418 direction = dp .motion .direction ,
419419 recording = dp .recording ,
420- segment_index = dp .segment_index ,
420+ segment_indices = list ( dp .segment_index ) ,
421421 depth_lim = dp .depth_lim ,
422422 scatter_decimate = dp .scatter_decimate ,
423423 color_amplitude = dp .color_amplitude ,
0 commit comments