44from warnings import warn
55
66from .base import BaseWidget , to_attr , default_backend_kwargs
7- from .utils import get_some_colors
7+ from .utils import get_some_colors , validate_segment_indices
88
99
1010class BaseRasterWidget (BaseWidget ):
@@ -23,7 +23,7 @@ class BaseRasterWidget(BaseWidget):
2323 converted to a dict of dicts with segment 0.
2424 unit_ids : array-like | None, default: None
2525 List of unit_ids to plot
26- segment_index : int | list | None, default: None
26+ segment_indices : list | None, default: None
2727 For multi-segment data, specifies which segment(s) to plot. If None, uses all available segments.
2828 For single-segment data, this parameter is ignored.
2929 total_duration : int | None, default: None
@@ -65,7 +65,7 @@ def __init__(
6565 spike_train_data : dict ,
6666 y_axis_data : dict ,
6767 unit_ids : list | None = None ,
68- segment_index : int | list | None = None ,
68+ segment_indices : list | None = None ,
6969 total_duration : int | None = None ,
7070 plot_histograms : bool = False ,
7171 bins : int | None = None ,
@@ -93,22 +93,17 @@ def __init__(
9393 available_segments .sort () # Ensure consistent ordering
9494
9595 # Determine which segments to use
96- if segment_index is None :
96+ if segment_indices is None :
9797 # Use all segments by default
9898 segments_to_use = available_segments
99- elif isinstance (segment_index , int ):
100- # Single segment specified
101- if segment_index not in available_segments :
102- raise ValueError (f"segment_index { segment_index } not found in data" )
103- segments_to_use = [segment_index ]
104- elif isinstance (segment_index , list ):
99+ elif isinstance (segment_indices , list ):
105100 # Multiple segments specified
106- for idx in segment_index :
101+ for idx in segment_indices :
107102 if idx not in available_segments :
108- raise ValueError (f"segment_index { idx } not found in data " )
109- segments_to_use = segment_index
103+ raise ValueError (f"segment_index { idx } not found in avialable segments { available_segments } " )
104+ segments_to_use = segment_indices
110105 else :
111- raise ValueError ("segment_index must be int, list, or None" )
106+ raise ValueError ("segment_index must be ` list` or ` None` " )
112107
113108 # Get all unit IDs present in any segment if not specified
114109 if unit_ids is None :
@@ -391,7 +386,7 @@ class RasterWidget(BaseRasterWidget):
391386 A sorting object
392387 sorting_analyzer : SortingAnalyzer | None, default: None
393388 A sorting analyzer object
394- segment_index : int or list of int or None, default: None
389+ segment_indices : list of int or None, default: None
395390 The segment index or indices to use. If None and there are multiple segments, defaults to 0.
396391 If a list of indices is provided, spike trains are concatenated across the specified segments.
397392 unit_ids : list
@@ -406,7 +401,7 @@ def __init__(
406401 self ,
407402 sorting = None ,
408403 sorting_analyzer = None ,
409- segment_index = None ,
404+ segment_indices = None ,
410405 unit_ids = None ,
411406 time_range = None ,
412407 color = "k" ,
@@ -424,30 +419,7 @@ def __init__(
424419
425420 sorting = self .ensure_sorting (sorting )
426421
427- num_segments = sorting .get_num_segments ()
428-
429- # Handle segment_index input
430- if num_segments > 1 :
431- if segment_index is None :
432- warn ("More than one segment available! Using `segment_index = 0`." )
433- segment_index = 0
434- else :
435- segment_index = 0
436-
437- # Convert segment_index to list for consistent processing
438- if isinstance (segment_index , int ):
439- segment_indices = [segment_index ]
440- elif isinstance (segment_index , list ):
441- segment_indices = segment_index
442- else :
443- raise ValueError ("segment_index must be an int or a list of ints" )
444-
445- # Validate segment indices
446- for idx in segment_indices :
447- if not isinstance (idx , int ):
448- raise ValueError (f"Each segment index must be an integer, got { type (idx )} " )
449- if idx < 0 or idx >= num_segments :
450- raise ValueError (f"segment_index { idx } out of range (0 to { num_segments - 1 } )" )
422+ segment_indices = validate_segment_indices (sorting , segment_indices )
451423
452424 if unit_ids is None :
453425 unit_ids = sorting .unit_ids
0 commit comments